feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)

核心功能:
- 新增AICodeService(550行):AI代码生成核心服务
- 新增AIController(257行):4个API端点
- 新增dc_tool_c_ai_history表:存储对话历史
- 实现自我修正机制:最多3次智能重试
- 集成LLMFactory:复用通用能力层
- 10个Few-shot示例:覆盖Level 1-4场景

技术优化:
- 修复NaN序列化问题(Python端转None)
- 修复数据传递问题(从Session获取真实数据)
- 优化System Prompt(明确环境信息)
- 调整Few-shot示例(移除import语句)

测试结果:
- 通过率:9/11(81.8%) 达到MVP标准
- 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类
- 待优化:数值清洗、智能去重(已记录技术债务TD-C-006)

API端点:
- POST /api/v1/dc/tool-c/ai/generate(生成代码)
- POST /api/v1/dc/tool-c/ai/execute(执行代码)
- POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位)
- GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史)

文档更新:
- 新增Day 3开发完成总结(770行)
- 新增复杂场景优化技术债务(TD-C-006)
- 更新工具C当前状态文档
- 更新技术债务清单

影响范围:
- backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件)
- backend/scripts/create-tool-c-ai-history-table.mjs(新增)
- backend/prisma/schema.prisma(新增DcToolCAiHistory模型)
- extraction_service/services/dc_executor.py(NaN序列化修复)
- docs/03-业务模块/DC-数据清洗整理/*(5份文档更新)

Breaking Changes: 无

总代码行数:+950行

Refs: #Tool-C-Day3
This commit is contained in:
2025-12-07 16:21:32 +08:00
parent 2348234013
commit f01981bf78
68 changed files with 6257 additions and 17 deletions

View File

@@ -31,3 +31,4 @@ COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.file_key IS 'OSS存储路径: dc/
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.columns IS '列名数组 ["age", "gender", "diagnosis"]';
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.expires_at IS '过期时间创建后10分钟';

View File

@@ -873,3 +873,30 @@ model DcToolCSession {
@@map("dc_tool_c_sessions")
@@schema("dc_schema")
}
// Tool C AI对话历史表
model DcToolCAiHistory {
id String @id @default(uuid())
sessionId String @map("session_id") // 关联Tool C Session
userId String @map("user_id")
role String @map("role") // user/assistant/system
content String @db.Text // 消息内容
// Tool C特有字段
generatedCode String? @db.Text @map("generated_code") // AI生成的代码
codeExplanation String? @db.Text @map("code_explanation") // 代码解释
executeStatus String? @map("execute_status") // pending/success/failed
executeResult Json? @map("execute_result") // 执行结果
executeError String? @db.Text @map("execute_error") // 错误信息
retryCount Int @default(0) @map("retry_count") // 重试次数
// LLM相关
model String? @map("model") // deepseek-v3/qwen3等
createdAt DateTime @default(now()) @map("created_at")
@@index([sessionId])
@@index([userId])
@@index([createdAt])
@@map("dc_tool_c_ai_history")
@@schema("dc_schema")
}

View File

@@ -180,3 +180,5 @@ function extractCodeBlocks(obj, blocks = []) {

View File

@@ -199,3 +199,5 @@ checkDCTables();

View File

@@ -0,0 +1,155 @@
/**
* 创建 Tool C AI对话历史表
*
* 执行方式node scripts/create-tool-c-ai-history-table.mjs
*/
import { PrismaClient } from '@prisma/client';
const prisma = new PrismaClient();
async function createAiHistoryTable() {
console.log('========================================');
console.log('开始创建 Tool C AI对话历史表');
console.log('========================================\n');
try {
// 1. 检查表是否已存在
console.log('[1/4] 检查表是否已存在...');
const checkResult = await prisma.$queryRawUnsafe(`
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'dc_schema'
AND table_name = 'dc_tool_c_ai_history'
) as exists
`);
const tableExists = checkResult[0].exists;
if (tableExists) {
console.log('✅ 表已存在: dc_schema.dc_tool_c_ai_history');
console.log('\n如需重新创建请手动执行: DROP TABLE dc_schema.dc_tool_c_ai_history CASCADE;\n');
return;
}
console.log('✅ 表不存在,准备创建\n');
// 2. 创建表
console.log('[2/4] 创建表 dc_tool_c_ai_history...');
await prisma.$executeRawUnsafe(`
CREATE TABLE dc_schema.dc_tool_c_ai_history (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
session_id VARCHAR(255) NOT NULL,
user_id VARCHAR(255) NOT NULL,
role VARCHAR(50) NOT NULL,
content TEXT NOT NULL,
-- Tool C特有字段
generated_code TEXT,
code_explanation TEXT,
execute_status VARCHAR(50),
execute_result JSONB,
execute_error TEXT,
retry_count INTEGER DEFAULT 0,
-- LLM相关
model VARCHAR(100),
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`);
console.log('✅ 表创建成功\n');
// 3. 创建索引
console.log('[3/4] 创建索引...');
await prisma.$executeRawUnsafe(`
CREATE INDEX idx_dc_tool_c_ai_history_session_id
ON dc_schema.dc_tool_c_ai_history(session_id)
`);
await prisma.$executeRawUnsafe(`
CREATE INDEX idx_dc_tool_c_ai_history_user_id
ON dc_schema.dc_tool_c_ai_history(user_id)
`);
await prisma.$executeRawUnsafe(`
CREATE INDEX idx_dc_tool_c_ai_history_created_at
ON dc_schema.dc_tool_c_ai_history(created_at)
`);
console.log('✅ 索引创建成功\n');
// 4. 添加注释
console.log('[4/4] 添加表注释...');
await prisma.$executeRawUnsafe(`
COMMENT ON TABLE dc_schema.dc_tool_c_ai_history
IS 'Tool C (科研数据编辑器) AI对话历史表'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.session_id
IS '关联Tool C Session ID'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.generated_code
IS 'AI生成的Pandas代码'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.execute_status
IS '执行状态: pending/success/failed'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.retry_count
IS '自我修正重试次数'
`);
console.log('✅ 注释添加成功\n');
// 5. 验证表创建
console.log('========================================');
console.log('验证表结构');
console.log('========================================\n');
const columns = await prisma.$queryRawUnsafe(`
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = 'dc_schema'
AND table_name = 'dc_tool_c_ai_history'
ORDER BY ordinal_position
`);
console.log('表结构:');
console.table(columns);
const indexes = await prisma.$queryRawUnsafe(`
SELECT indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'dc_schema'
AND tablename = 'dc_tool_c_ai_history'
`);
console.log('\n索引:');
console.table(indexes);
console.log('\n========================================');
console.log('🎉 Tool C AI对话历史表创建成功');
console.log('========================================\n');
console.log('表名: dc_schema.dc_tool_c_ai_history');
console.log(`列数: ${columns.length}`);
console.log(`索引数: ${indexes.length}\n`);
} catch (error) {
console.error('\n❌ 创建表失败:', error.message);
console.error('\n详细错误:');
console.error(error);
process.exit(1);
} finally {
await prisma.$disconnect();
}
}
// 执行
createAiHistoryTable()
.then(() => {
console.log('脚本执行完成');
process.exit(0);
})
.catch((error) => {
console.error('脚本执行失败:', error);
process.exit(1);
});

View File

@@ -0,0 +1,142 @@
/**
* 创建 Tool C Session 表
*
* 执行方式node scripts/create-tool-c-table.js
*/
const { PrismaClient } = require('@prisma/client');
const fs = require('fs');
const path = require('path');
const prisma = new PrismaClient();
async function createToolCTable() {
console.log('========================================');
console.log('开始创建 Tool C Session 表');
console.log('========================================\n');
try {
// 1. 检查表是否已存在
console.log('[1/4] 检查表是否已存在...');
const checkResult = await prisma.$queryRawUnsafe(`
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'dc_schema'
AND table_name = 'dc_tool_c_sessions'
) as exists
`);
const tableExists = checkResult[0].exists;
if (tableExists) {
console.log('✅ 表已存在: dc_schema.dc_tool_c_sessions');
console.log('\n是否需要重新创建(这将删除现有数据)');
console.log('如需重新创建,请手动执行: DROP TABLE dc_schema.dc_tool_c_sessions CASCADE;\n');
return;
}
console.log('✅ 表不存在,准备创建\n');
// 2. 创建表
console.log('[2/4] 创建表 dc_tool_c_sessions...');
await prisma.$executeRawUnsafe(`
CREATE TABLE dc_schema.dc_tool_c_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(255) NOT NULL,
file_name VARCHAR(500) NOT NULL,
file_key VARCHAR(500) NOT NULL,
total_rows INTEGER NOT NULL,
total_cols INTEGER NOT NULL,
columns JSONB NOT NULL,
encoding VARCHAR(50),
file_size INTEGER NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL
)
`);
console.log('✅ 表创建成功\n');
// 3. 创建索引
console.log('[3/4] 创建索引...');
await prisma.$executeRawUnsafe(`
CREATE INDEX idx_dc_tool_c_sessions_user_id ON dc_schema.dc_tool_c_sessions(user_id)
`);
await prisma.$executeRawUnsafe(`
CREATE INDEX idx_dc_tool_c_sessions_expires_at ON dc_schema.dc_tool_c_sessions(expires_at)
`);
console.log('✅ 索引创建成功\n');
// 4. 添加注释
console.log('[4/4] 添加表注释...');
await prisma.$executeRawUnsafe(`
COMMENT ON TABLE dc_schema.dc_tool_c_sessions IS 'Tool C (科研数据编辑器) Session会话表'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.file_key IS 'OSS存储路径: dc/tool-c/sessions/{timestamp}-{fileName}'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.columns IS '列名数组 ["age", "gender", "diagnosis"]'
`);
await prisma.$executeRawUnsafe(`
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.expires_at IS '过期时间创建后10分钟'
`);
console.log('✅ 注释添加成功\n');
// 5. 验证表创建
console.log('========================================');
console.log('验证表结构');
console.log('========================================\n');
const columns = await prisma.$queryRawUnsafe(`
SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_schema = 'dc_schema'
AND table_name = 'dc_tool_c_sessions'
ORDER BY ordinal_position
`);
console.log('表结构:');
console.table(columns);
const indexes = await prisma.$queryRawUnsafe(`
SELECT indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'dc_schema'
AND tablename = 'dc_tool_c_sessions'
`);
console.log('\n索引:');
console.table(indexes);
console.log('\n========================================');
console.log('🎉 Tool C Session 表创建成功!');
console.log('========================================\n');
console.log('表名: dc_schema.dc_tool_c_sessions');
console.log(`列数: ${columns.length}`);
console.log(`索引数: ${indexes.length}\n`);
} catch (error) {
console.error('\n❌ 创建表失败:', error.message);
console.error('\n详细错误:');
console.error(error);
process.exit(1);
} finally {
await prisma.$disconnect();
}
}
// 执行
createToolCTable()
.then(() => {
console.log('脚本执行完成');
process.exit(0);
})
.catch((error) => {
console.error('脚本执行失败:', error);
process.exit(1);
});

View File

@@ -136,3 +136,4 @@ createToolCTable()
process.exit(1);
});

View File

@@ -303,3 +303,5 @@ runTests().catch((error) => {

View File

@@ -282,3 +282,5 @@ Content-Type: application/json

View File

@@ -361,3 +361,5 @@ export class ExcelExporter {

View File

@@ -6,6 +6,7 @@
import { FastifyInstance } from 'fastify';
import { registerToolBRoutes } from './tool-b/routes/index.js';
import { toolCRoutes } from './tool-c/routes/index.js';
import { templateService } from './tool-b/services/TemplateService.js';
import { logger } from '../../common/logging/index.js';
@@ -20,6 +21,11 @@ export async function registerDCRoutes(fastify: FastifyInstance) {
await registerToolBRoutes(instance);
}, { prefix: '/api/v1/dc/tool-b' });
// 注册Tool C路由科研数据编辑器
await fastify.register(async (instance) => {
await toolCRoutes(instance);
}, { prefix: '/api/v1/dc/tool-c' });
logger.info('[DC] DC module routes registered');
}

View File

@@ -218,3 +218,5 @@ export const conflictDetectionService = new ConflictDetectionService();

View File

@@ -246,3 +246,5 @@ export const templateService = new TemplateService();

View File

@@ -169,3 +169,4 @@ curl -X POST http://localhost:3000/api/v1/dc/tool-c/test/execute \
- [ ] AI代码生成服务LLMFactory集成
- [ ] 前端基础框架搭建

View File

@@ -0,0 +1,256 @@
/**
* AI代码生成控制器
*
* API端点
* - POST /ai/generate 生成代码(不执行)
* - POST /ai/execute 执行代码
* - POST /ai/process 生成并执行(带重试)
* - GET /ai/history/:sessionId 获取对话历史
*
* @module AIController
*/
import { FastifyRequest, FastifyReply } from 'fastify';
import { logger } from '../../../../common/logging/index.js';
import { aiCodeService } from '../services/AICodeService.js';
// ==================== 请求参数类型定义 ====================
interface GenerateCodeBody {
sessionId: string;
message: string;
}
interface ExecuteCodeBody {
sessionId: string;
code: string;
messageId: string;
}
interface ProcessBody {
sessionId: string;
message: string;
maxRetries?: number;
}
interface HistoryParams {
sessionId: string;
}
// ==================== 控制器 ====================
export class AIController {
/**
* POST /api/v1/dc/tool-c/ai/generate
* 生成代码(不执行)
*/
async generateCode(request: FastifyRequest, reply: FastifyReply) {
try {
const { sessionId, message } = request.body as GenerateCodeBody;
logger.info(`[AIController] 收到生成代码请求: sessionId=${sessionId}`);
// 参数验证
if (!sessionId || !message) {
return reply.code(400).send({
success: false,
error: '缺少必要参数sessionId 或 message'
});
}
if (message.trim().length === 0) {
return reply.code(400).send({
success: false,
error: '消息内容不能为空'
});
}
// 生成代码
const result = await aiCodeService.generateCode(sessionId, message);
logger.info(`[AIController] 代码生成成功: messageId=${result.messageId}`);
return reply.code(200).send({
success: true,
message: 'AI代码生成成功',
data: {
code: result.code,
explanation: result.explanation,
messageId: result.messageId
}
});
} catch (error: any) {
logger.error(`[AIController] generateCode失败: ${error.message}`);
return reply.code(500).send({
success: false,
error: error.message || 'AI代码生成失败请重试'
});
}
}
/**
* POST /api/v1/dc/tool-c/ai/execute
* 执行代码
*/
async executeCode(request: FastifyRequest, reply: FastifyReply) {
try {
const { sessionId, code, messageId } = request.body as ExecuteCodeBody;
logger.info(`[AIController] 收到执行代码请求: messageId=${messageId}`);
// 参数验证
if (!sessionId || !code || !messageId) {
return reply.code(400).send({
success: false,
error: '缺少必要参数sessionId、code 或 messageId'
});
}
// 执行代码
const result = await aiCodeService.executeCode(sessionId, code, messageId);
if (result.success) {
logger.info(`[AIController] 代码执行成功: messageId=${messageId}`);
return reply.code(200).send({
success: true,
message: '代码执行成功',
data: {
result: result.result,
newDataPreview: result.newDataPreview
}
});
} else {
logger.warn(`[AIController] 代码执行失败: ${result.error}`);
return reply.code(200).send({
success: false,
error: result.error || '代码执行失败',
data: null
});
}
} catch (error: any) {
logger.error(`[AIController] executeCode失败: ${error.message}`);
return reply.code(500).send({
success: false,
error: error.message || '代码执行失败,请重试'
});
}
}
/**
* POST /api/v1/dc/tool-c/ai/process
* 生成并执行(一步到位,带重试)
*/
async process(request: FastifyRequest, reply: FastifyReply) {
try {
const { sessionId, message, maxRetries = 3 } = request.body as ProcessBody;
logger.info(`[AIController] 收到处理请求: sessionId=${sessionId}, maxRetries=${maxRetries}`);
// 参数验证
if (!sessionId || !message) {
return reply.code(400).send({
success: false,
error: '缺少必要参数sessionId 或 message'
});
}
if (message.trim().length === 0) {
return reply.code(400).send({
success: false,
error: '消息内容不能为空'
});
}
if (maxRetries < 1 || maxRetries > 5) {
return reply.code(400).send({
success: false,
error: '重试次数必须在1-5之间'
});
}
// 生成并执行(带重试)
const result = await aiCodeService.generateAndExecute(
sessionId,
message,
maxRetries
);
logger.info(`[AIController] 处理成功: 重试${result.retryCount}次后成功`);
return reply.code(200).send({
success: true,
message: `代码执行成功${result.retryCount > 0 ? `(重试${result.retryCount}次)` : ''}`,
data: {
code: result.code,
explanation: result.explanation,
executeResult: result.executeResult,
retryCount: result.retryCount,
messageId: result.messageId
}
});
} catch (error: any) {
logger.error(`[AIController] process失败: ${error.message}`);
return reply.code(500).send({
success: false,
error: error.message || '处理失败,请重试'
});
}
}
/**
* GET /api/v1/dc/tool-c/ai/history/:sessionId
* 获取对话历史
*/
async getHistory(
request: FastifyRequest<{ Params: HistoryParams; Querystring: { limit?: string } }>,
reply: FastifyReply
) {
try {
const { sessionId } = request.params;
const limit = request.query.limit ? parseInt(request.query.limit) : 10;
logger.info(`[AIController] 获取对话历史: sessionId=${sessionId}, limit=${limit}`);
// 参数验证
if (!sessionId) {
return reply.code(400).send({
success: false,
error: '缺少必要参数sessionId'
});
}
if (limit < 1 || limit > 50) {
return reply.code(400).send({
success: false,
error: '历史记录数量必须在1-50之间'
});
}
// 获取历史
const history = await aiCodeService.getHistory(sessionId, limit);
logger.info(`[AIController] 获取历史成功: ${history.length}`);
return reply.code(200).send({
success: true,
data: {
sessionId,
history,
count: history.length
}
});
} catch (error: any) {
logger.error(`[AIController] getHistory失败: ${error.message}`);
return reply.code(500).send({
success: false,
error: error.message || '获取对话历史失败'
});
}
}
}
// ==================== 导出单例实例 ====================
export const aiController = new AIController();

View File

@@ -7,6 +7,7 @@
import { FastifyInstance } from 'fastify';
import { testController } from '../controllers/TestController.js';
import { sessionController } from '../controllers/SessionController.js';
import { aiController } from '../controllers/AIController.js';
export async function toolCRoutes(fastify: FastifyInstance) {
// ==================== 测试路由Day 1 ====================
@@ -57,5 +58,27 @@ export async function toolCRoutes(fastify: FastifyInstance) {
fastify.post('/sessions/:id/heartbeat', {
handler: sessionController.updateHeartbeat.bind(sessionController),
});
// ==================== AI代码生成路由Day 3 ====================
// 生成代码(不执行)
fastify.post('/ai/generate', {
handler: aiController.generateCode.bind(aiController),
});
// 执行代码
fastify.post('/ai/execute', {
handler: aiController.executeCode.bind(aiController),
});
// 生成并执行(一步到位,带重试)
fastify.post('/ai/process', {
handler: aiController.process.bind(aiController),
});
// 获取对话历史
fastify.get('/ai/history/:sessionId', {
handler: aiController.getHistory.bind(aiController),
});
}

View File

@@ -0,0 +1,549 @@
/**
* AI代码生成服务
*
* 功能:
* - 使用LLM生成Pandas数据清洗代码
* - 执行生成的代码
* - 自我修正最多3次重试
* - 管理对话历史
*
* @module AICodeService
*/
import { logger } from '../../../../common/logging/index.js';
import { prisma } from '../../../../config/database.js';
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
import { ModelType, Message } from '../../../../common/llm/adapters/types.js';
import { sessionService } from './SessionService.js';
import { pythonExecutorService } from './PythonExecutorService.js';
// ==================== 类型定义 ====================
interface SessionData {
id: string;
fileName: string;
totalRows: number;
totalCols: number;
columns: string[];
}
interface GenerateCodeResult {
code: string;
explanation: string;
messageId: string;
}
interface ExecuteCodeResult {
success: boolean;
result?: any;
error?: string;
newDataPreview?: any[];
}
interface ProcessResult extends GenerateCodeResult {
executeResult: ExecuteCodeResult;
retryCount: number;
}
// ==================== AI代码生成服务 ====================
export class AICodeService {
/**
* 生成Pandas代码
* @param sessionId - Tool C Session ID
* @param userMessage - 用户自然语言需求
* @returns { code, explanation, messageId }
*/
async generateCode(
sessionId: string,
userMessage: string
): Promise<GenerateCodeResult> {
try {
logger.info(`[AICodeService] 生成代码: sessionId=${sessionId}`);
// 1. 获取Session信息数据集元数据
const session = await sessionService.getSession(sessionId);
// 2. 构建System Prompt含10个Few-shot示例
const systemPrompt = this.buildSystemPrompt({
id: session.id,
fileName: session.fileName,
totalRows: session.totalRows,
totalCols: session.totalCols,
columns: session.columns
});
// 3. 获取对话历史最近5轮
const history = await this.getHistory(sessionId, 5);
// 4. 调用LLM复用LLMFactory
const llm = LLMFactory.getAdapter('deepseek-v3' as ModelType);
const response = await llm.chat([
{ role: 'system', content: systemPrompt },
...history,
{ role: 'user', content: userMessage }
], {
temperature: 0.1, // 低温度,确保代码准确性
maxTokens: 2000, // 足够生成代码+解释
topP: 0.9
});
logger.info(`[AICodeService] LLM响应成功开始解析...`);
// 5. 解析AI回复提取code和explanation
const parsed = this.parseAIResponse(response.content);
// 6. 保存到数据库
const messageId = await this.saveMessages(
sessionId,
session.userId,
userMessage,
parsed.code,
parsed.explanation
);
logger.info(`[AICodeService] 代码生成成功: messageId=${messageId}`);
return {
code: parsed.code,
explanation: parsed.explanation,
messageId
};
} catch (error: any) {
logger.error(`[AICodeService] 生成代码失败: ${error.message}`);
throw error;
}
}
/**
* 执行Python代码
* @param sessionId - Tool C Session ID
* @param code - Python代码
* @param messageId - 关联的消息ID
* @returns { success, result, newDataPreview }
*/
async executeCode(
sessionId: string,
code: string,
messageId: string
): Promise<ExecuteCodeResult> {
try {
logger.info(`[AICodeService] 执行代码: messageId=${messageId}`);
// 1. 从Session获取完整数据
const fullData = await sessionService.getFullData(sessionId);
logger.info(`[AICodeService] 获取Session数据: ${fullData.length}`);
// 2. 调用Python服务执行
const result = await pythonExecutorService.executeCode(
fullData,
code
);
// 2. 更新消息状态
// @ts-ignore - DcToolCAiHistory模型
await prisma.dcToolCAiHistory.update({
where: { id: messageId },
data: {
executeStatus: result.success ? 'success' : 'failed',
executeResult: result.result_data ? JSON.parse(JSON.stringify({ data: result.result_data })) : undefined,
executeError: result.error || undefined
}
});
// 4. 如果成功获取新数据预览前50行
if (result.success && result.result_data) {
const preview = Array.isArray(result.result_data)
? result.result_data.slice(0, 50)
: result.result_data;
logger.info(`[AICodeService] 代码执行成功`);
return {
success: true,
result: result.result_data,
newDataPreview: preview
};
}
logger.warn(`[AICodeService] 代码执行失败: ${result.error}`);
return {
success: false,
error: result.error || '执行失败,未知错误'
};
} catch (error: any) {
logger.error(`[AICodeService] 执行代码异常: ${error.message}`);
// 更新为失败状态
// @ts-ignore - DcToolCAiHistory模型
await prisma.dcToolCAiHistory.update({
where: { id: messageId },
data: {
executeStatus: 'failed',
executeError: error.message
}
});
return {
success: false,
error: error.message
};
}
}
/**
* 生成并执行(带自我修正)
* @param sessionId - Tool C Session ID
* @param userMessage - 用户需求
* @param maxRetries - 最大重试次数默认3
* @returns { code, explanation, executeResult, retryCount }
*/
async generateAndExecute(
sessionId: string,
userMessage: string,
maxRetries: number = 3
): Promise<ProcessResult> {
let attempt = 0;
let lastError: string | null = null;
let generated: GenerateCodeResult | null = null;
while (attempt < maxRetries) {
try {
logger.info(`[AICodeService] 尝试 ${attempt + 1}/${maxRetries}`);
// 构建带错误反馈的提示词
const enhancedMessage = attempt === 0
? userMessage
: `${userMessage}\n\n上次执行错误${lastError}\n请修正代码确保代码正确且符合Pandas语法。`;
// 生成代码
generated = await this.generateCode(sessionId, enhancedMessage);
// 执行代码
const executeResult = await this.executeCode(
sessionId,
generated.code,
generated.messageId
);
if (executeResult.success) {
// ✅ 成功
logger.info(`[AICodeService] 执行成功(尝试${attempt + 1}次)`);
// 更新重试次数
// @ts-ignore - DcToolCAiHistory模型
await prisma.dcToolCAiHistory.update({
where: { id: generated.messageId },
data: { retryCount: attempt }
});
return {
...generated,
executeResult,
retryCount: attempt
};
}
// ❌ 失败,准备重试
lastError = executeResult.error || '未知错误';
attempt++;
logger.warn(`[AICodeService] 执行失败(尝试${attempt}/${maxRetries}: ${lastError}`);
} catch (error: any) {
logger.error(`[AICodeService] 异常: ${error.message}`);
lastError = error.message;
attempt++;
}
}
// 3次仍失败
throw new Error(
`代码执行失败(已重试${maxRetries}次)。最后错误:${lastError}` +
`建议:请调整需求描述或检查数据列名是否正确。`
);
}
/**
* 获取对话历史
* @param sessionId - Tool C Session ID
* @param limit - 最近N轮对话默认5轮即10条消息
* @returns 消息列表
*/
async getHistory(sessionId: string, limit: number = 5): Promise<Message[]> {
try {
// @ts-ignore - DcToolCAiHistory模型
const records = await prisma.dcToolCAiHistory.findMany({
where: { sessionId },
orderBy: { createdAt: 'desc' },
take: limit * 2, // user + assistant 成对
select: {
role: true,
content: true
}
});
// 反转顺序(最旧的在前)
return records.reverse().map((r: any) => ({
role: r.role as 'user' | 'assistant' | 'system',
content: r.content
}));
} catch (error: any) {
logger.error(`[AICodeService] 获取历史失败: ${error.message}`);
return [];
}
}
// ==================== 辅助方法 ====================
/**
* 构建System Prompt含10个Few-shot示例
* @private
*/
private buildSystemPrompt(session: SessionData): string {
return `你是医疗科研数据清洗专家负责生成Pandas代码来清洗整理数据。
## 当前数据集信息
- 文件名: ${session.fileName}
- 行数: ${session.totalRows}
- 列数: ${session.totalCols}
- 列名: ${session.columns.join(', ')}
## 执行环境(重要)
**已预导入的库请直接使用不要再import**
- pandas 已导入为 pd
- numpy 已导入为 np
- df 变量已加载当前数据集
**不可用的库:**
- sklearn未安装请使用pandas/numpy替代方案
- scipy未安装
- 其他第三方库
**示例:直接使用,无需导入**
\`\`\`python
# ✅ 正确:直接使用预导入的库
df['age_clean'] = df['age'].fillna(df['age'].median())
df['group'] = np.where(df['age'] > 60, '老年', '非老年')
# ❌ 错误:不要再导入
import pandas as pd # 会报错!
import numpy as np # 会报错!
import sklearn # 未安装,会报错!
\`\`\`
## 安全规则(强制)
1. 只能操作df变量不能修改其他变量
2. **禁止任何import语句**pandas和numpy已预导入
3. 禁止使用eval、exec、__import__等危险函数
4. 必须进行异常处理
5. 返回格式必须是JSON: {"code": "...", "explanation": "..."}
## Few-shot示例
### 示例1: 统一缺失值标记
用户: 把所有代表缺失的符号(-、不详、NA、N/A统一替换为标准空值
代码:
\`\`\`python
try:
df = df.replace(['-', '不详', 'NA', 'N/A', '\\\\', '未查'], np.nan)
print(f'缺失值标记统一完成,当前缺失值数量: {df.isna().sum().sum()}')
except Exception as e:
print(f'处理错误: {e}')
\`\`\`
说明: 将多种缺失值表示统一为NaN便于后续统计分析
### 示例2: 数值列清洗
用户: 把肌酐列里的非数字符号去掉,<0.1按0.05处理,转为数值类型
代码:
\`\`\`python
df['creatinine'] = df['creatinine'].astype(str).str.replace('>', '').str.replace('<', '')
df.loc[df['creatinine'] == '0.1', 'creatinine'] = '0.05'
df['creatinine'] = pd.to_numeric(df['creatinine'], errors='coerce')
\`\`\`
说明: 检验科数据常含符号,需清理后才能计算
### 示例3: 分类变量编码
用户: 把性别列转为数字,男=1女=0
代码:
\`\`\`python
df['gender_code'] = df['gender'].map({'男': 1, '女': 0})
\`\`\`
说明: 将文本分类变量转为数值,便于统计建模
### 示例4: 连续变量分箱
用户: 把年龄按18岁、60岁分为未成年、成年、老年三组
代码:
\`\`\`python
df['age_group'] = pd.cut(df['age'], bins=[0, 18, 60, 120], labels=['未成年', '成年', '老年'], right=False)
\`\`\`
说明: 将连续变量离散化,用于分层分析或卡方检验
### 示例5: BMI计算与分类
用户: 根据身高(cm)和体重(kg)计算BMI并标记BMI≥28为肥胖
代码:
\`\`\`python
df['BMI'] = df['weight'] / (df['height'] / 100) ** 2
df['obesity'] = df['BMI'].apply(lambda x: '肥胖' if x >= 28 else '正常')
\`\`\`
说明: 临床常用的体质指标计算和分类
### 示例6: 日期计算
用户: 根据入院日期和出院日期计算住院天数
代码:
\`\`\`python
df['admission_date'] = pd.to_datetime(df['admission_date'])
df['discharge_date'] = pd.to_datetime(df['discharge_date'])
df['length_of_stay'] = (df['discharge_date'] - df['admission_date']).dt.days
\`\`\`
说明: 医疗数据常需计算时间间隔(住院天数、随访时间等)
### 示例7: 条件筛选(入组标准)
用户: 筛选出年龄≥18岁、诊断为糖尿病、且血糖≥7.0的患者
代码:
\`\`\`python
df_selected = df[(df['age'] >= 18) & (df['diagnosis'] == '糖尿病') & (df['glucose'] >= 7.0)]
\`\`\`
说明: 临床研究常需根据入组/排除标准筛选病例
### 示例8: 简单缺失值填补
用户: 用中位数填补BMI列的缺失值
代码:
\`\`\`python
bmi_median = df['BMI'].median()
df['BMI'] = df['BMI'].fillna(bmi_median)
\`\`\`
说明: 简单填补适用于缺失率<5%且MCAR完全随机缺失的情况
### 示例9: 智能多列缺失值填补
用户: 对BMI、年龄、肌酐列的缺失值进行智能填补
代码:
\`\`\`python
try:
# 检查列是否存在
cols = ['BMI', 'age', 'creatinine']
missing_cols = [c for c in cols if c not in df.columns]
if missing_cols:
print(f'警告:以下列不存在: {missing_cols}')
else:
# 转换为数值类型
for col in cols:
df[col] = pd.to_numeric(df[col], errors='coerce')
# 根据列特性选择填补策略
df['age'] = df['age'].fillna(df['age'].median()) # 年龄用中位数
df['BMI'] = df['BMI'].fillna(df.groupby('gender')['BMI'].transform('median')) # BMI按性别分组填补
df['creatinine'] = df['creatinine'].fillna(df['creatinine'].mean()) # 肌酐用均值
print('缺失值填补完成')
print(f'年龄缺失: {df["age"].isna().sum()}')
print(f'BMI缺失: {df["BMI"].isna().sum()}')
print(f'肌酐缺失: {df["creatinine"].isna().sum()}')
except Exception as e:
print(f'填补错误: {e}')
\`\`\`
说明: 根据医学变量特性选择不同填补策略年龄用中位数稳健BMI按性别分组考虑性别差异肌酐用均值
### 示例10: 智能去重
用户: 按患者ID去重保留检查日期最新的记录
代码:
\`\`\`python
df['check_date'] = pd.to_datetime(df['check_date'])
df = df.sort_values('check_date').drop_duplicates(subset=['patient_id'], keep='last')
\`\`\`
说明: 先按日期排序,再去重保留最后一条(最新)
## 用户当前请求
请根据以上示例和当前数据集信息生成代码并解释。返回JSON格式{"code": "...", "explanation": "..."}`;
}
/**
* 解析AI回复提取code和explanation
* @private
*/
private parseAIResponse(content: string): { code: string; explanation: string } {
try {
// 方法1尝试解析JSON
const json = JSON.parse(content);
if (json.code && json.explanation) {
return { code: json.code, explanation: json.explanation };
}
} catch {
// 方法2正则提取代码块
const codeMatch = content.match(/```python\n([\s\S]+?)\n```/);
const code = codeMatch ? codeMatch[1].trim() : '';
// 提取解释(代码块之外的文本)
let explanation = content.replace(/```python[\s\S]+?```/g, '').trim();
// 如果没有单独的解释尝试提取JSON中的explanation
try {
const jsonMatch = content.match(/\{[\s\S]*"explanation":\s*"([^"]+)"[\s\S]*\}/);
if (jsonMatch) {
explanation = jsonMatch[1];
}
} catch {
// 忽略
}
if (code) {
return { code, explanation: explanation || '代码已生成' };
}
}
logger.error(`[AICodeService] AI回复格式错误: ${content.substring(0, 200)}`);
throw new Error('AI回复格式错误无法提取代码。请重试。');
}
/**
* 保存消息到数据库
* @private
*/
private async saveMessages(
sessionId: string,
userId: string,
userMessage: string,
code: string,
explanation: string
): Promise<string> {
try {
// 保存用户消息
// @ts-ignore - DcToolCAiHistory模型
await prisma.dcToolCAiHistory.create({
data: {
sessionId,
userId,
role: 'user',
content: userMessage
}
});
// 保存AI回复
// @ts-ignore - DcToolCAiHistory模型
const assistantMessage = await prisma.dcToolCAiHistory.create({
data: {
sessionId,
userId,
role: 'assistant',
content: explanation,
generatedCode: code,
codeExplanation: explanation,
executeStatus: 'pending',
model: 'deepseek-v3'
}
});
return assistantMessage.id;
} catch (error: any) {
logger.error(`[AICodeService] 保存消息失败: ${error.message}`);
throw error;
}
}
}
// ==================== 导出单例实例 ====================
export const aiCodeService = new AICodeService();

View File

@@ -26,3 +26,5 @@ Write-Host "✅ 完成!" -ForegroundColor Green

View File

@@ -380,3 +380,4 @@ runAllTests()
process.exit(1);
});

View File

@@ -0,0 +1,341 @@
/**
* Tool C Day 3 测试脚本
*
* 测试内容:
* 1. 10个Few-shot示例场景测试
* 2. AI自我修正机制测试重试
* 3. 端到端测试
*
* 前提:
* - 需要先创建一个Session使用Day 2的upload接口
* - 需要Python服务运行端口8000
* - 需要后端服务运行端口3000
*
* 执行方式node test-tool-c-day3.mjs
*/
import axios from 'axios';
import FormData from 'form-data';
import * as XLSX from 'xlsx';
const BASE_URL = 'http://localhost:3000';
const API_PREFIX = '/api/v1/dc/tool-c';
let testSessionId = null;
// ==================== 辅助函数 ====================
function printSection(title) {
console.log('\n' + '='.repeat(70));
console.log(` ${title}`);
console.log('='.repeat(70) + '\n');
}
function printSuccess(message) {
console.log('✅ ' + message);
}
function printError(message) {
console.log('❌ ' + message);
}
function printInfo(message) {
console.log(' ' + message);
}
// ==================== 准备测试Session ====================
async function createTestSession() {
printSection('准备创建测试Session');
try {
// 创建测试Excel数据
const testData = [
{ patient_id: 'P001', name: '张三', age: 25, gender: '男', diagnosis: '感冒', sbp: 120, dbp: 80, weight: 70, height: 175, BMI: '', creatinine: '>100', check_date: '2024-01-01' },
{ patient_id: 'P002', name: '李四', age: 65, gender: '女', diagnosis: '高血压', sbp: 150, dbp: 95, weight: 65, height: 160, BMI: '', creatinine: '<0.1', check_date: '2024-01-05' },
{ patient_id: 'P003', name: '王五', age: 45, gender: '男', diagnosis: '糖尿病', sbp: 135, dbp: 85, weight: 80, height: 170, BMI: '', creatinine: '85', check_date: '2024-01-03' },
{ patient_id: 'P004', name: '赵六', age: 70, gender: '女', diagnosis: '冠心病', sbp: 160, dbp: 100, weight: 60, height: 155, BMI: '', creatinine: '120', check_date: '2024-01-10' },
{ patient_id: 'P005', name: '钱七', age: 35, gender: '男', diagnosis: '胃炎', sbp: 110, dbp: 70, weight: 75, height: 180, BMI: '', creatinine: '-', check_date: '2024-01-08' },
{ patient_id: 'P003', name: '王五', age: 45, gender: '男', diagnosis: '糖尿病', sbp: 138, dbp: 88, weight: 82, height: 170, BMI: '', creatinine: '88', check_date: '2024-01-12' }, // 重复ID日期更新
];
// 生成Excel
const ws = XLSX.utils.json_to_sheet(testData);
const wb = XLSX.utils.book_new();
XLSX.utils.book_append_sheet(wb, ws, 'Sheet1');
const excelBuffer = XLSX.write(wb, { type: 'buffer', bookType: 'xlsx' });
// 上传创建Session
const form = new FormData();
form.append('file', excelBuffer, {
filename: 'test-medical-data-day3.xlsx',
contentType: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
});
const response = await axios.post(
`${BASE_URL}${API_PREFIX}/sessions/upload`,
form,
{ headers: form.getHeaders(), timeout: 10000 }
);
if (response.data.success) {
testSessionId = response.data.data.sessionId;
printSuccess(`Session创建成功: ${testSessionId}`);
printInfo(`数据: ${testData.length}行 x ${Object.keys(testData[0]).length}`);
return true;
} else {
printError('Session创建失败');
return false;
}
} catch (error) {
printError('Session创建异常: ' + error.message);
return false;
}
}
// ==================== AI测试函数 ====================
async function testAIGenerate(testName, userMessage, shouldSucceed = true) {
printSection(`测试: ${testName}`);
try {
printInfo(`用户需求: ${userMessage}`);
const response = await axios.post(
`${BASE_URL}${API_PREFIX}/ai/generate`,
{
sessionId: testSessionId,
message: userMessage
},
{ timeout: 30000 } // AI调用可能需要较长时间
);
if (response.data.success) {
printSuccess('AI生成代码成功');
console.log('\n生成的代码:');
console.log('```python');
console.log(response.data.data.code);
console.log('```\n');
console.log('解释:', response.data.data.explanation);
console.log('MessageID:', response.data.data.messageId);
return { success: true, data: response.data.data };
} else {
if (shouldSucceed) {
printError('AI生成失败: ' + response.data.error);
} else {
printInfo('预期失败: ' + response.data.error);
}
return { success: false, error: response.data.error };
}
} catch (error) {
printError('AI生成异常: ' + (error.response?.data?.error || error.message));
return { success: false, error: error.message };
}
}
async function testAIProcess(testName, userMessage) {
printSection(`测试(一步到位): ${testName}`);
try {
printInfo(`用户需求: ${userMessage}`);
const response = await axios.post(
`${BASE_URL}${API_PREFIX}/ai/process`,
{
sessionId: testSessionId,
message: userMessage,
maxRetries: 3
},
{ timeout: 60000 } // 带重试可能需要更长时间
);
if (response.data.success) {
printSuccess(`执行成功${response.data.data.retryCount > 0 ? `(重试${response.data.data.retryCount}次)` : ''}`);
console.log('\n生成的代码:');
console.log('```python');
console.log(response.data.data.code);
console.log('```\n');
console.log('解释:', response.data.data.explanation);
if (response.data.data.executeResult.success) {
printSuccess('代码执行成功');
console.log('数据预览前5行:');
console.log(JSON.stringify(response.data.data.executeResult.newDataPreview?.slice(0, 5), null, 2));
}
return { success: true, data: response.data.data };
} else {
printError('处理失败: ' + response.data.error);
return { success: false, error: response.data.error };
}
} catch (error) {
printError('处理异常: ' + (error.response?.data?.error || error.message));
return { success: false, error: error.message };
}
}
// ==================== 主测试函数 ====================
async function runAllTests() {
console.log('\n' + '🚀'.repeat(35));
console.log(' Tool C Day 3 测试 - AI代码生成');
console.log('🚀'.repeat(35));
const results = {};
try {
// 0. 准备测试Session
const sessionCreated = await createTestSession();
if (!sessionCreated) {
printError('测试Session创建失败无法继续');
return;
}
await new Promise(resolve => setTimeout(resolve, 2000));
// ==================== 10个Few-shot示例测试 ====================
// 测试1: 统一缺失值标记
let result = await testAIProcess(
'示例1: 统一缺失值标记',
'把所有代表缺失的符号(-、不详、NA、N/A统一替换为标准空值'
);
results['示例1-缺失值'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试2: 数值列清洗
result = await testAIProcess(
'示例2: 数值列清洗',
'把creatinine列里的非数字符号去掉<0.1按0.05处理,转为数值类型'
);
results['示例2-数值清洗'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试3: 分类变量编码
result = await testAIProcess(
'示例3: 分类变量编码',
'把gender列转为数字男=1女=0'
);
results['示例3-编码'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试4: 连续变量分箱
result = await testAIProcess(
'示例4: 连续变量分箱',
'把age列按18岁、60岁分为未成年、成年、老年三组'
);
results['示例4-分箱'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试5: BMI计算
result = await testAIProcess(
'示例5: BMI计算',
'根据weight和height计算BMI并标记BMI≥28为肥胖'
);
results['示例5-BMI'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试6: 条件筛选
result = await testAIProcess(
'示例6: 条件筛选',
'筛选出年龄≥60岁、且sbp≥140的患者'
);
results['示例6-筛选'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试7: 智能去重
result = await testAIProcess(
'示例7: 智能去重',
'按patient_id去重保留check_date最新的记录'
);
results['示例7-去重'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试8: 中位数填补(简化版,跳过多重插补)
result = await testAIProcess(
'示例8: 缺失值填补',
'用age列的中位数填补age列的缺失值'
);
results['示例8-填补'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试9: 统计汇总
result = await testAIProcess(
'示例9: 统计汇总',
'按diagnosis分组统计每个诊断的平均年龄和患者数量'
);
results['示例9-统计'] = result.success;
await new Promise(resolve => setTimeout(resolve, 3000));
// 测试10: 复杂计算
result = await testAIProcess(
'示例10: 复杂计算',
'根据sbp判断血压分类正常(<140)、高血压I级(140-159)、高血压II级(≥160)'
);
results['示例10-分类'] = result.success;
// ==================== 对话历史测试 ====================
printSection('测试: 获取对话历史');
try {
const historyResponse = await axios.get(
`${BASE_URL}${API_PREFIX}/ai/history/${testSessionId}?limit=5`
);
if (historyResponse.data.success) {
printSuccess(`获取历史成功: ${historyResponse.data.data.count}`);
results['对话历史'] = true;
} else {
printError('获取历史失败');
results['对话历史'] = false;
}
} catch (error) {
printError('获取历史异常: ' + error.message);
results['对话历史'] = false;
}
} catch (error) {
printError('测试过程中发生异常: ' + error.message);
console.error(error);
}
// 汇总结果
printSection('测试结果汇总');
let passed = 0;
let total = 0;
for (const [testName, result] of Object.entries(results)) {
total++;
if (result) {
passed++;
console.log(`${testName.padEnd(20)}: ✅ 通过`);
} else {
console.log(`${testName.padEnd(20)}: ❌ 失败`);
}
}
console.log('\n' + '-'.repeat(70));
console.log(`总计: ${passed}/${total} 通过 (${((passed/total)*100).toFixed(1)}%)`);
console.log('-'.repeat(70));
if (passed === total) {
console.log('\n🎉 所有测试通过Day 3 AI功能完成\n');
} else if (passed >= total * 0.7) {
console.log(`\n⚠️ 有 ${total - passed} 个测试失败但通过率≥70%,基本可用\n`);
} else {
console.log(`\n❌ 通过率过低,需要调试\n`);
}
}
// 执行测试
runAllTests()
.then(() => {
console.log('测试完成');
process.exit(0);
})
.catch((error) => {
console.error('测试失败:', error);
process.exit(1);
});