feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)
核心功能: - 新增AICodeService(550行):AI代码生成核心服务 - 新增AIController(257行):4个API端点 - 新增dc_tool_c_ai_history表:存储对话历史 - 实现自我修正机制:最多3次智能重试 - 集成LLMFactory:复用通用能力层 - 10个Few-shot示例:覆盖Level 1-4场景 技术优化: - 修复NaN序列化问题(Python端转None) - 修复数据传递问题(从Session获取真实数据) - 优化System Prompt(明确环境信息) - 调整Few-shot示例(移除import语句) 测试结果: - 通过率:9/11(81.8%) 达到MVP标准 - 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类 - 待优化:数值清洗、智能去重(已记录技术债务TD-C-006) API端点: - POST /api/v1/dc/tool-c/ai/generate(生成代码) - POST /api/v1/dc/tool-c/ai/execute(执行代码) - POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位) - GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史) 文档更新: - 新增Day 3开发完成总结(770行) - 新增复杂场景优化技术债务(TD-C-006) - 更新工具C当前状态文档 - 更新技术债务清单 影响范围: - backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件) - backend/scripts/create-tool-c-ai-history-table.mjs(新增) - backend/prisma/schema.prisma(新增DcToolCAiHistory模型) - extraction_service/services/dc_executor.py(NaN序列化修复) - docs/03-业务模块/DC-数据清洗整理/*(5份文档更新) Breaking Changes: 无 总代码行数:+950行 Refs: #Tool-C-Day3
This commit is contained in:
@@ -31,3 +31,4 @@ COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.file_key IS 'OSS存储路径: dc/
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.columns IS '列名数组 ["age", "gender", "diagnosis"]';
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.expires_at IS '过期时间(创建后10分钟)';
|
||||
|
||||
|
||||
|
||||
@@ -873,3 +873,30 @@ model DcToolCSession {
|
||||
@@map("dc_tool_c_sessions")
|
||||
@@schema("dc_schema")
|
||||
}
|
||||
|
||||
// Tool C AI对话历史表
|
||||
model DcToolCAiHistory {
|
||||
id String @id @default(uuid())
|
||||
sessionId String @map("session_id") // 关联Tool C Session
|
||||
userId String @map("user_id")
|
||||
role String @map("role") // user/assistant/system
|
||||
content String @db.Text // 消息内容
|
||||
|
||||
// Tool C特有字段
|
||||
generatedCode String? @db.Text @map("generated_code") // AI生成的代码
|
||||
codeExplanation String? @db.Text @map("code_explanation") // 代码解释
|
||||
executeStatus String? @map("execute_status") // pending/success/failed
|
||||
executeResult Json? @map("execute_result") // 执行结果
|
||||
executeError String? @db.Text @map("execute_error") // 错误信息
|
||||
retryCount Int @default(0) @map("retry_count") // 重试次数
|
||||
|
||||
// LLM相关
|
||||
model String? @map("model") // deepseek-v3/qwen3等
|
||||
createdAt DateTime @default(now()) @map("created_at")
|
||||
|
||||
@@index([sessionId])
|
||||
@@index([userId])
|
||||
@@index([createdAt])
|
||||
@@map("dc_tool_c_ai_history")
|
||||
@@schema("dc_schema")
|
||||
}
|
||||
|
||||
@@ -180,3 +180,5 @@ function extractCodeBlocks(obj, blocks = []) {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -199,3 +199,5 @@ checkDCTables();
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
155
backend/scripts/create-tool-c-ai-history-table.mjs
Normal file
155
backend/scripts/create-tool-c-ai-history-table.mjs
Normal file
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* 创建 Tool C AI对话历史表
|
||||
*
|
||||
* 执行方式:node scripts/create-tool-c-ai-history-table.mjs
|
||||
*/
|
||||
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
async function createAiHistoryTable() {
|
||||
console.log('========================================');
|
||||
console.log('开始创建 Tool C AI对话历史表');
|
||||
console.log('========================================\n');
|
||||
|
||||
try {
|
||||
// 1. 检查表是否已存在
|
||||
console.log('[1/4] 检查表是否已存在...');
|
||||
const checkResult = await prisma.$queryRawUnsafe(`
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'dc_schema'
|
||||
AND table_name = 'dc_tool_c_ai_history'
|
||||
) as exists
|
||||
`);
|
||||
|
||||
const tableExists = checkResult[0].exists;
|
||||
|
||||
if (tableExists) {
|
||||
console.log('✅ 表已存在: dc_schema.dc_tool_c_ai_history');
|
||||
console.log('\n如需重新创建,请手动执行: DROP TABLE dc_schema.dc_tool_c_ai_history CASCADE;\n');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('✅ 表不存在,准备创建\n');
|
||||
|
||||
// 2. 创建表
|
||||
console.log('[2/4] 创建表 dc_tool_c_ai_history...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE TABLE dc_schema.dc_tool_c_ai_history (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(50) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
|
||||
-- Tool C特有字段
|
||||
generated_code TEXT,
|
||||
code_explanation TEXT,
|
||||
execute_status VARCHAR(50),
|
||||
execute_result JSONB,
|
||||
execute_error TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
|
||||
-- LLM相关
|
||||
model VARCHAR(100),
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`);
|
||||
console.log('✅ 表创建成功\n');
|
||||
|
||||
// 3. 创建索引
|
||||
console.log('[3/4] 创建索引...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE INDEX idx_dc_tool_c_ai_history_session_id
|
||||
ON dc_schema.dc_tool_c_ai_history(session_id)
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE INDEX idx_dc_tool_c_ai_history_user_id
|
||||
ON dc_schema.dc_tool_c_ai_history(user_id)
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE INDEX idx_dc_tool_c_ai_history_created_at
|
||||
ON dc_schema.dc_tool_c_ai_history(created_at)
|
||||
`);
|
||||
console.log('✅ 索引创建成功\n');
|
||||
|
||||
// 4. 添加注释
|
||||
console.log('[4/4] 添加表注释...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON TABLE dc_schema.dc_tool_c_ai_history
|
||||
IS 'Tool C (科研数据编辑器) AI对话历史表'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.session_id
|
||||
IS '关联Tool C Session ID'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.generated_code
|
||||
IS 'AI生成的Pandas代码'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.execute_status
|
||||
IS '执行状态: pending/success/failed'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_ai_history.retry_count
|
||||
IS '自我修正重试次数'
|
||||
`);
|
||||
console.log('✅ 注释添加成功\n');
|
||||
|
||||
// 5. 验证表创建
|
||||
console.log('========================================');
|
||||
console.log('验证表结构');
|
||||
console.log('========================================\n');
|
||||
|
||||
const columns = await prisma.$queryRawUnsafe(`
|
||||
SELECT column_name, data_type, is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'dc_schema'
|
||||
AND table_name = 'dc_tool_c_ai_history'
|
||||
ORDER BY ordinal_position
|
||||
`);
|
||||
|
||||
console.log('表结构:');
|
||||
console.table(columns);
|
||||
|
||||
const indexes = await prisma.$queryRawUnsafe(`
|
||||
SELECT indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'dc_schema'
|
||||
AND tablename = 'dc_tool_c_ai_history'
|
||||
`);
|
||||
|
||||
console.log('\n索引:');
|
||||
console.table(indexes);
|
||||
|
||||
console.log('\n========================================');
|
||||
console.log('🎉 Tool C AI对话历史表创建成功!');
|
||||
console.log('========================================\n');
|
||||
console.log('表名: dc_schema.dc_tool_c_ai_history');
|
||||
console.log(`列数: ${columns.length}`);
|
||||
console.log(`索引数: ${indexes.length}\n`);
|
||||
|
||||
} catch (error) {
|
||||
console.error('\n❌ 创建表失败:', error.message);
|
||||
console.error('\n详细错误:');
|
||||
console.error(error);
|
||||
process.exit(1);
|
||||
} finally {
|
||||
await prisma.$disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
// 执行
|
||||
createAiHistoryTable()
|
||||
.then(() => {
|
||||
console.log('脚本执行完成');
|
||||
process.exit(0);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('脚本执行失败:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
142
backend/scripts/create-tool-c-table.js
Normal file
142
backend/scripts/create-tool-c-table.js
Normal file
@@ -0,0 +1,142 @@
|
||||
/**
|
||||
* 创建 Tool C Session 表
|
||||
*
|
||||
* 执行方式:node scripts/create-tool-c-table.js
|
||||
*/
|
||||
|
||||
const { PrismaClient } = require('@prisma/client');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
async function createToolCTable() {
|
||||
console.log('========================================');
|
||||
console.log('开始创建 Tool C Session 表');
|
||||
console.log('========================================\n');
|
||||
|
||||
try {
|
||||
// 1. 检查表是否已存在
|
||||
console.log('[1/4] 检查表是否已存在...');
|
||||
const checkResult = await prisma.$queryRawUnsafe(`
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'dc_schema'
|
||||
AND table_name = 'dc_tool_c_sessions'
|
||||
) as exists
|
||||
`);
|
||||
|
||||
const tableExists = checkResult[0].exists;
|
||||
|
||||
if (tableExists) {
|
||||
console.log('✅ 表已存在: dc_schema.dc_tool_c_sessions');
|
||||
console.log('\n是否需要重新创建?(这将删除现有数据)');
|
||||
console.log('如需重新创建,请手动执行: DROP TABLE dc_schema.dc_tool_c_sessions CASCADE;\n');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('✅ 表不存在,准备创建\n');
|
||||
|
||||
// 2. 创建表
|
||||
console.log('[2/4] 创建表 dc_tool_c_sessions...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE TABLE dc_schema.dc_tool_c_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
file_name VARCHAR(500) NOT NULL,
|
||||
file_key VARCHAR(500) NOT NULL,
|
||||
|
||||
total_rows INTEGER NOT NULL,
|
||||
total_cols INTEGER NOT NULL,
|
||||
columns JSONB NOT NULL,
|
||||
encoding VARCHAR(50),
|
||||
file_size INTEGER NOT NULL,
|
||||
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL
|
||||
)
|
||||
`);
|
||||
console.log('✅ 表创建成功\n');
|
||||
|
||||
// 3. 创建索引
|
||||
console.log('[3/4] 创建索引...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE INDEX idx_dc_tool_c_sessions_user_id ON dc_schema.dc_tool_c_sessions(user_id)
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
CREATE INDEX idx_dc_tool_c_sessions_expires_at ON dc_schema.dc_tool_c_sessions(expires_at)
|
||||
`);
|
||||
console.log('✅ 索引创建成功\n');
|
||||
|
||||
// 4. 添加注释
|
||||
console.log('[4/4] 添加表注释...');
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON TABLE dc_schema.dc_tool_c_sessions IS 'Tool C (科研数据编辑器) Session会话表'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.file_key IS 'OSS存储路径: dc/tool-c/sessions/{timestamp}-{fileName}'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.columns IS '列名数组 ["age", "gender", "diagnosis"]'
|
||||
`);
|
||||
await prisma.$executeRawUnsafe(`
|
||||
COMMENT ON COLUMN dc_schema.dc_tool_c_sessions.expires_at IS '过期时间(创建后10分钟)'
|
||||
`);
|
||||
console.log('✅ 注释添加成功\n');
|
||||
|
||||
// 5. 验证表创建
|
||||
console.log('========================================');
|
||||
console.log('验证表结构');
|
||||
console.log('========================================\n');
|
||||
|
||||
const columns = await prisma.$queryRawUnsafe(`
|
||||
SELECT column_name, data_type, is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'dc_schema'
|
||||
AND table_name = 'dc_tool_c_sessions'
|
||||
ORDER BY ordinal_position
|
||||
`);
|
||||
|
||||
console.log('表结构:');
|
||||
console.table(columns);
|
||||
|
||||
const indexes = await prisma.$queryRawUnsafe(`
|
||||
SELECT indexname, indexdef
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'dc_schema'
|
||||
AND tablename = 'dc_tool_c_sessions'
|
||||
`);
|
||||
|
||||
console.log('\n索引:');
|
||||
console.table(indexes);
|
||||
|
||||
console.log('\n========================================');
|
||||
console.log('🎉 Tool C Session 表创建成功!');
|
||||
console.log('========================================\n');
|
||||
console.log('表名: dc_schema.dc_tool_c_sessions');
|
||||
console.log(`列数: ${columns.length}`);
|
||||
console.log(`索引数: ${indexes.length}\n`);
|
||||
|
||||
} catch (error) {
|
||||
console.error('\n❌ 创建表失败:', error.message);
|
||||
console.error('\n详细错误:');
|
||||
console.error(error);
|
||||
process.exit(1);
|
||||
} finally {
|
||||
await prisma.$disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
// 执行
|
||||
createToolCTable()
|
||||
.then(() => {
|
||||
console.log('脚本执行完成');
|
||||
process.exit(0);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('脚本执行失败:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
|
||||
@@ -136,3 +136,4 @@ createToolCTable()
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
|
||||
|
||||
@@ -303,3 +303,5 @@ runTests().catch((error) => {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -244,3 +244,5 @@ runTest()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -282,3 +282,5 @@ Content-Type: application/json
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -361,3 +361,5 @@ export class ExcelExporter {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import { FastifyInstance } from 'fastify';
|
||||
import { registerToolBRoutes } from './tool-b/routes/index.js';
|
||||
import { toolCRoutes } from './tool-c/routes/index.js';
|
||||
import { templateService } from './tool-b/services/TemplateService.js';
|
||||
import { logger } from '../../common/logging/index.js';
|
||||
|
||||
@@ -20,6 +21,11 @@ export async function registerDCRoutes(fastify: FastifyInstance) {
|
||||
await registerToolBRoutes(instance);
|
||||
}, { prefix: '/api/v1/dc/tool-b' });
|
||||
|
||||
// 注册Tool C路由(科研数据编辑器)
|
||||
await fastify.register(async (instance) => {
|
||||
await toolCRoutes(instance);
|
||||
}, { prefix: '/api/v1/dc/tool-c' });
|
||||
|
||||
logger.info('[DC] DC module routes registered');
|
||||
}
|
||||
|
||||
|
||||
@@ -218,3 +218,5 @@ export const conflictDetectionService = new ConflictDetectionService();
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -246,3 +246,5 @@ export const templateService = new TemplateService();
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -169,3 +169,4 @@ curl -X POST http://localhost:3000/api/v1/dc/tool-c/test/execute \
|
||||
- [ ] AI代码生成服务(LLMFactory集成)
|
||||
- [ ] 前端基础框架搭建
|
||||
|
||||
|
||||
|
||||
256
backend/src/modules/dc/tool-c/controllers/AIController.ts
Normal file
256
backend/src/modules/dc/tool-c/controllers/AIController.ts
Normal file
@@ -0,0 +1,256 @@
|
||||
/**
|
||||
* AI代码生成控制器
|
||||
*
|
||||
* API端点:
|
||||
* - POST /ai/generate 生成代码(不执行)
|
||||
* - POST /ai/execute 执行代码
|
||||
* - POST /ai/process 生成并执行(带重试)
|
||||
* - GET /ai/history/:sessionId 获取对话历史
|
||||
*
|
||||
* @module AIController
|
||||
*/
|
||||
|
||||
import { FastifyRequest, FastifyReply } from 'fastify';
|
||||
import { logger } from '../../../../common/logging/index.js';
|
||||
import { aiCodeService } from '../services/AICodeService.js';
|
||||
|
||||
// ==================== 请求参数类型定义 ====================
|
||||
|
||||
interface GenerateCodeBody {
|
||||
sessionId: string;
|
||||
message: string;
|
||||
}
|
||||
|
||||
interface ExecuteCodeBody {
|
||||
sessionId: string;
|
||||
code: string;
|
||||
messageId: string;
|
||||
}
|
||||
|
||||
interface ProcessBody {
|
||||
sessionId: string;
|
||||
message: string;
|
||||
maxRetries?: number;
|
||||
}
|
||||
|
||||
interface HistoryParams {
|
||||
sessionId: string;
|
||||
}
|
||||
|
||||
// ==================== 控制器 ====================
|
||||
|
||||
export class AIController {
|
||||
|
||||
/**
|
||||
* POST /api/v1/dc/tool-c/ai/generate
|
||||
* 生成代码(不执行)
|
||||
*/
|
||||
async generateCode(request: FastifyRequest, reply: FastifyReply) {
|
||||
try {
|
||||
const { sessionId, message } = request.body as GenerateCodeBody;
|
||||
|
||||
logger.info(`[AIController] 收到生成代码请求: sessionId=${sessionId}`);
|
||||
|
||||
// 参数验证
|
||||
if (!sessionId || !message) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '缺少必要参数:sessionId 或 message'
|
||||
});
|
||||
}
|
||||
|
||||
if (message.trim().length === 0) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '消息内容不能为空'
|
||||
});
|
||||
}
|
||||
|
||||
// 生成代码
|
||||
const result = await aiCodeService.generateCode(sessionId, message);
|
||||
|
||||
logger.info(`[AIController] 代码生成成功: messageId=${result.messageId}`);
|
||||
|
||||
return reply.code(200).send({
|
||||
success: true,
|
||||
message: 'AI代码生成成功',
|
||||
data: {
|
||||
code: result.code,
|
||||
explanation: result.explanation,
|
||||
messageId: result.messageId
|
||||
}
|
||||
});
|
||||
} catch (error: any) {
|
||||
logger.error(`[AIController] generateCode失败: ${error.message}`);
|
||||
return reply.code(500).send({
|
||||
success: false,
|
||||
error: error.message || 'AI代码生成失败,请重试'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* POST /api/v1/dc/tool-c/ai/execute
|
||||
* 执行代码
|
||||
*/
|
||||
async executeCode(request: FastifyRequest, reply: FastifyReply) {
|
||||
try {
|
||||
const { sessionId, code, messageId } = request.body as ExecuteCodeBody;
|
||||
|
||||
logger.info(`[AIController] 收到执行代码请求: messageId=${messageId}`);
|
||||
|
||||
// 参数验证
|
||||
if (!sessionId || !code || !messageId) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '缺少必要参数:sessionId、code 或 messageId'
|
||||
});
|
||||
}
|
||||
|
||||
// 执行代码
|
||||
const result = await aiCodeService.executeCode(sessionId, code, messageId);
|
||||
|
||||
if (result.success) {
|
||||
logger.info(`[AIController] 代码执行成功: messageId=${messageId}`);
|
||||
return reply.code(200).send({
|
||||
success: true,
|
||||
message: '代码执行成功',
|
||||
data: {
|
||||
result: result.result,
|
||||
newDataPreview: result.newDataPreview
|
||||
}
|
||||
});
|
||||
} else {
|
||||
logger.warn(`[AIController] 代码执行失败: ${result.error}`);
|
||||
return reply.code(200).send({
|
||||
success: false,
|
||||
error: result.error || '代码执行失败',
|
||||
data: null
|
||||
});
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error(`[AIController] executeCode失败: ${error.message}`);
|
||||
return reply.code(500).send({
|
||||
success: false,
|
||||
error: error.message || '代码执行失败,请重试'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* POST /api/v1/dc/tool-c/ai/process
|
||||
* 生成并执行(一步到位,带重试)
|
||||
*/
|
||||
async process(request: FastifyRequest, reply: FastifyReply) {
|
||||
try {
|
||||
const { sessionId, message, maxRetries = 3 } = request.body as ProcessBody;
|
||||
|
||||
logger.info(`[AIController] 收到处理请求: sessionId=${sessionId}, maxRetries=${maxRetries}`);
|
||||
|
||||
// 参数验证
|
||||
if (!sessionId || !message) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '缺少必要参数:sessionId 或 message'
|
||||
});
|
||||
}
|
||||
|
||||
if (message.trim().length === 0) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '消息内容不能为空'
|
||||
});
|
||||
}
|
||||
|
||||
if (maxRetries < 1 || maxRetries > 5) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '重试次数必须在1-5之间'
|
||||
});
|
||||
}
|
||||
|
||||
// 生成并执行(带重试)
|
||||
const result = await aiCodeService.generateAndExecute(
|
||||
sessionId,
|
||||
message,
|
||||
maxRetries
|
||||
);
|
||||
|
||||
logger.info(`[AIController] 处理成功: 重试${result.retryCount}次后成功`);
|
||||
|
||||
return reply.code(200).send({
|
||||
success: true,
|
||||
message: `代码执行成功${result.retryCount > 0 ? `(重试${result.retryCount}次)` : ''}`,
|
||||
data: {
|
||||
code: result.code,
|
||||
explanation: result.explanation,
|
||||
executeResult: result.executeResult,
|
||||
retryCount: result.retryCount,
|
||||
messageId: result.messageId
|
||||
}
|
||||
});
|
||||
} catch (error: any) {
|
||||
logger.error(`[AIController] process失败: ${error.message}`);
|
||||
return reply.code(500).send({
|
||||
success: false,
|
||||
error: error.message || '处理失败,请重试'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* GET /api/v1/dc/tool-c/ai/history/:sessionId
|
||||
* 获取对话历史
|
||||
*/
|
||||
async getHistory(
|
||||
request: FastifyRequest<{ Params: HistoryParams; Querystring: { limit?: string } }>,
|
||||
reply: FastifyReply
|
||||
) {
|
||||
try {
|
||||
const { sessionId } = request.params;
|
||||
const limit = request.query.limit ? parseInt(request.query.limit) : 10;
|
||||
|
||||
logger.info(`[AIController] 获取对话历史: sessionId=${sessionId}, limit=${limit}`);
|
||||
|
||||
// 参数验证
|
||||
if (!sessionId) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '缺少必要参数:sessionId'
|
||||
});
|
||||
}
|
||||
|
||||
if (limit < 1 || limit > 50) {
|
||||
return reply.code(400).send({
|
||||
success: false,
|
||||
error: '历史记录数量必须在1-50之间'
|
||||
});
|
||||
}
|
||||
|
||||
// 获取历史
|
||||
const history = await aiCodeService.getHistory(sessionId, limit);
|
||||
|
||||
logger.info(`[AIController] 获取历史成功: ${history.length}条`);
|
||||
|
||||
return reply.code(200).send({
|
||||
success: true,
|
||||
data: {
|
||||
sessionId,
|
||||
history,
|
||||
count: history.length
|
||||
}
|
||||
});
|
||||
} catch (error: any) {
|
||||
logger.error(`[AIController] getHistory失败: ${error.message}`);
|
||||
return reply.code(500).send({
|
||||
success: false,
|
||||
error: error.message || '获取对话历史失败'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 导出单例实例 ====================
|
||||
|
||||
export const aiController = new AIController();
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import { FastifyInstance } from 'fastify';
|
||||
import { testController } from '../controllers/TestController.js';
|
||||
import { sessionController } from '../controllers/SessionController.js';
|
||||
import { aiController } from '../controllers/AIController.js';
|
||||
|
||||
export async function toolCRoutes(fastify: FastifyInstance) {
|
||||
// ==================== 测试路由(Day 1) ====================
|
||||
@@ -57,5 +58,27 @@ export async function toolCRoutes(fastify: FastifyInstance) {
|
||||
fastify.post('/sessions/:id/heartbeat', {
|
||||
handler: sessionController.updateHeartbeat.bind(sessionController),
|
||||
});
|
||||
|
||||
// ==================== AI代码生成路由(Day 3) ====================
|
||||
|
||||
// 生成代码(不执行)
|
||||
fastify.post('/ai/generate', {
|
||||
handler: aiController.generateCode.bind(aiController),
|
||||
});
|
||||
|
||||
// 执行代码
|
||||
fastify.post('/ai/execute', {
|
||||
handler: aiController.executeCode.bind(aiController),
|
||||
});
|
||||
|
||||
// 生成并执行(一步到位,带重试)
|
||||
fastify.post('/ai/process', {
|
||||
handler: aiController.process.bind(aiController),
|
||||
});
|
||||
|
||||
// 获取对话历史
|
||||
fastify.get('/ai/history/:sessionId', {
|
||||
handler: aiController.getHistory.bind(aiController),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
549
backend/src/modules/dc/tool-c/services/AICodeService.ts
Normal file
549
backend/src/modules/dc/tool-c/services/AICodeService.ts
Normal file
@@ -0,0 +1,549 @@
|
||||
/**
|
||||
* AI代码生成服务
|
||||
*
|
||||
* 功能:
|
||||
* - 使用LLM生成Pandas数据清洗代码
|
||||
* - 执行生成的代码
|
||||
* - 自我修正(最多3次重试)
|
||||
* - 管理对话历史
|
||||
*
|
||||
* @module AICodeService
|
||||
*/
|
||||
|
||||
import { logger } from '../../../../common/logging/index.js';
|
||||
import { prisma } from '../../../../config/database.js';
|
||||
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
|
||||
import { ModelType, Message } from '../../../../common/llm/adapters/types.js';
|
||||
import { sessionService } from './SessionService.js';
|
||||
import { pythonExecutorService } from './PythonExecutorService.js';
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
|
||||
interface SessionData {
|
||||
id: string;
|
||||
fileName: string;
|
||||
totalRows: number;
|
||||
totalCols: number;
|
||||
columns: string[];
|
||||
}
|
||||
|
||||
interface GenerateCodeResult {
|
||||
code: string;
|
||||
explanation: string;
|
||||
messageId: string;
|
||||
}
|
||||
|
||||
interface ExecuteCodeResult {
|
||||
success: boolean;
|
||||
result?: any;
|
||||
error?: string;
|
||||
newDataPreview?: any[];
|
||||
}
|
||||
|
||||
interface ProcessResult extends GenerateCodeResult {
|
||||
executeResult: ExecuteCodeResult;
|
||||
retryCount: number;
|
||||
}
|
||||
|
||||
// ==================== AI代码生成服务 ====================
|
||||
|
||||
export class AICodeService {
|
||||
|
||||
/**
|
||||
* 生成Pandas代码
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param userMessage - 用户自然语言需求
|
||||
* @returns { code, explanation, messageId }
|
||||
*/
|
||||
async generateCode(
|
||||
sessionId: string,
|
||||
userMessage: string
|
||||
): Promise<GenerateCodeResult> {
|
||||
try {
|
||||
logger.info(`[AICodeService] 生成代码: sessionId=${sessionId}`);
|
||||
|
||||
// 1. 获取Session信息(数据集元数据)
|
||||
const session = await sessionService.getSession(sessionId);
|
||||
|
||||
// 2. 构建System Prompt(含10个Few-shot示例)
|
||||
const systemPrompt = this.buildSystemPrompt({
|
||||
id: session.id,
|
||||
fileName: session.fileName,
|
||||
totalRows: session.totalRows,
|
||||
totalCols: session.totalCols,
|
||||
columns: session.columns
|
||||
});
|
||||
|
||||
// 3. 获取对话历史(最近5轮)
|
||||
const history = await this.getHistory(sessionId, 5);
|
||||
|
||||
// 4. 调用LLM(复用LLMFactory)
|
||||
const llm = LLMFactory.getAdapter('deepseek-v3' as ModelType);
|
||||
const response = await llm.chat([
|
||||
{ role: 'system', content: systemPrompt },
|
||||
...history,
|
||||
{ role: 'user', content: userMessage }
|
||||
], {
|
||||
temperature: 0.1, // 低温度,确保代码准确性
|
||||
maxTokens: 2000, // 足够生成代码+解释
|
||||
topP: 0.9
|
||||
});
|
||||
|
||||
logger.info(`[AICodeService] LLM响应成功,开始解析...`);
|
||||
|
||||
// 5. 解析AI回复(提取code和explanation)
|
||||
const parsed = this.parseAIResponse(response.content);
|
||||
|
||||
// 6. 保存到数据库
|
||||
const messageId = await this.saveMessages(
|
||||
sessionId,
|
||||
session.userId,
|
||||
userMessage,
|
||||
parsed.code,
|
||||
parsed.explanation
|
||||
);
|
||||
|
||||
logger.info(`[AICodeService] 代码生成成功: messageId=${messageId}`);
|
||||
|
||||
return {
|
||||
code: parsed.code,
|
||||
explanation: parsed.explanation,
|
||||
messageId
|
||||
};
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 生成代码失败: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行Python代码
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param code - Python代码
|
||||
* @param messageId - 关联的消息ID
|
||||
* @returns { success, result, newDataPreview }
|
||||
*/
|
||||
async executeCode(
|
||||
sessionId: string,
|
||||
code: string,
|
||||
messageId: string
|
||||
): Promise<ExecuteCodeResult> {
|
||||
try {
|
||||
logger.info(`[AICodeService] 执行代码: messageId=${messageId}`);
|
||||
|
||||
// 1. 从Session获取完整数据
|
||||
const fullData = await sessionService.getFullData(sessionId);
|
||||
logger.info(`[AICodeService] 获取Session数据: ${fullData.length}行`);
|
||||
|
||||
// 2. 调用Python服务执行
|
||||
const result = await pythonExecutorService.executeCode(
|
||||
fullData,
|
||||
code
|
||||
);
|
||||
|
||||
// 2. 更新消息状态
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: messageId },
|
||||
data: {
|
||||
executeStatus: result.success ? 'success' : 'failed',
|
||||
executeResult: result.result_data ? JSON.parse(JSON.stringify({ data: result.result_data })) : undefined,
|
||||
executeError: result.error || undefined
|
||||
}
|
||||
});
|
||||
|
||||
// 4. 如果成功,获取新数据预览(前50行)
|
||||
if (result.success && result.result_data) {
|
||||
const preview = Array.isArray(result.result_data)
|
||||
? result.result_data.slice(0, 50)
|
||||
: result.result_data;
|
||||
|
||||
logger.info(`[AICodeService] 代码执行成功`);
|
||||
|
||||
return {
|
||||
success: true,
|
||||
result: result.result_data,
|
||||
newDataPreview: preview
|
||||
};
|
||||
}
|
||||
|
||||
logger.warn(`[AICodeService] 代码执行失败: ${result.error}`);
|
||||
|
||||
return {
|
||||
success: false,
|
||||
error: result.error || '执行失败,未知错误'
|
||||
};
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 执行代码异常: ${error.message}`);
|
||||
|
||||
// 更新为失败状态
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: messageId },
|
||||
data: {
|
||||
executeStatus: 'failed',
|
||||
executeError: error.message
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
success: false,
|
||||
error: error.message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成并执行(带自我修正)
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param userMessage - 用户需求
|
||||
* @param maxRetries - 最大重试次数(默认3)
|
||||
* @returns { code, explanation, executeResult, retryCount }
|
||||
*/
|
||||
async generateAndExecute(
|
||||
sessionId: string,
|
||||
userMessage: string,
|
||||
maxRetries: number = 3
|
||||
): Promise<ProcessResult> {
|
||||
let attempt = 0;
|
||||
let lastError: string | null = null;
|
||||
let generated: GenerateCodeResult | null = null;
|
||||
|
||||
while (attempt < maxRetries) {
|
||||
try {
|
||||
logger.info(`[AICodeService] 尝试 ${attempt + 1}/${maxRetries}`);
|
||||
|
||||
// 构建带错误反馈的提示词
|
||||
const enhancedMessage = attempt === 0
|
||||
? userMessage
|
||||
: `${userMessage}\n\n上次执行错误:${lastError}\n请修正代码,确保代码正确且符合Pandas语法。`;
|
||||
|
||||
// 生成代码
|
||||
generated = await this.generateCode(sessionId, enhancedMessage);
|
||||
|
||||
// 执行代码
|
||||
const executeResult = await this.executeCode(
|
||||
sessionId,
|
||||
generated.code,
|
||||
generated.messageId
|
||||
);
|
||||
|
||||
if (executeResult.success) {
|
||||
// ✅ 成功
|
||||
logger.info(`[AICodeService] 执行成功(尝试${attempt + 1}次)`);
|
||||
|
||||
// 更新重试次数
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: generated.messageId },
|
||||
data: { retryCount: attempt }
|
||||
});
|
||||
|
||||
return {
|
||||
...generated,
|
||||
executeResult,
|
||||
retryCount: attempt
|
||||
};
|
||||
}
|
||||
|
||||
// ❌ 失败,准备重试
|
||||
lastError = executeResult.error || '未知错误';
|
||||
attempt++;
|
||||
|
||||
logger.warn(`[AICodeService] 执行失败(尝试${attempt}/${maxRetries}): ${lastError}`);
|
||||
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 异常: ${error.message}`);
|
||||
lastError = error.message;
|
||||
attempt++;
|
||||
}
|
||||
}
|
||||
|
||||
// 3次仍失败
|
||||
throw new Error(
|
||||
`代码执行失败(已重试${maxRetries}次)。最后错误:${lastError}。` +
|
||||
`建议:请调整需求描述或检查数据列名是否正确。`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取对话历史
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param limit - 最近N轮对话(默认5轮,即10条消息)
|
||||
* @returns 消息列表
|
||||
*/
|
||||
async getHistory(sessionId: string, limit: number = 5): Promise<Message[]> {
|
||||
try {
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
const records = await prisma.dcToolCAiHistory.findMany({
|
||||
where: { sessionId },
|
||||
orderBy: { createdAt: 'desc' },
|
||||
take: limit * 2, // user + assistant 成对
|
||||
select: {
|
||||
role: true,
|
||||
content: true
|
||||
}
|
||||
});
|
||||
|
||||
// 反转顺序(最旧的在前)
|
||||
return records.reverse().map((r: any) => ({
|
||||
role: r.role as 'user' | 'assistant' | 'system',
|
||||
content: r.content
|
||||
}));
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 获取历史失败: ${error.message}`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
/**
|
||||
* 构建System Prompt(含10个Few-shot示例)
|
||||
* @private
|
||||
*/
|
||||
private buildSystemPrompt(session: SessionData): string {
|
||||
return `你是医疗科研数据清洗专家,负责生成Pandas代码来清洗整理数据。
|
||||
|
||||
## 当前数据集信息
|
||||
- 文件名: ${session.fileName}
|
||||
- 行数: ${session.totalRows}
|
||||
- 列数: ${session.totalCols}
|
||||
- 列名: ${session.columns.join(', ')}
|
||||
|
||||
## 执行环境(重要)
|
||||
**已预导入的库(请直接使用,不要再import):**
|
||||
- pandas 已导入为 pd
|
||||
- numpy 已导入为 np
|
||||
- df 变量已加载当前数据集
|
||||
|
||||
**不可用的库:**
|
||||
- sklearn(未安装,请使用pandas/numpy替代方案)
|
||||
- scipy(未安装)
|
||||
- 其他第三方库
|
||||
|
||||
**示例:直接使用,无需导入**
|
||||
\`\`\`python
|
||||
# ✅ 正确:直接使用预导入的库
|
||||
df['age_clean'] = df['age'].fillna(df['age'].median())
|
||||
df['group'] = np.where(df['age'] > 60, '老年', '非老年')
|
||||
|
||||
# ❌ 错误:不要再导入
|
||||
import pandas as pd # 会报错!
|
||||
import numpy as np # 会报错!
|
||||
import sklearn # 未安装,会报错!
|
||||
\`\`\`
|
||||
|
||||
## 安全规则(强制)
|
||||
1. 只能操作df变量,不能修改其他变量
|
||||
2. **禁止任何import语句**(pandas和numpy已预导入)
|
||||
3. 禁止使用eval、exec、__import__等危险函数
|
||||
4. 必须进行异常处理
|
||||
5. 返回格式必须是JSON: {"code": "...", "explanation": "..."}
|
||||
|
||||
## Few-shot示例
|
||||
|
||||
### 示例1: 统一缺失值标记
|
||||
用户: 把所有代表缺失的符号(-、不详、NA、N/A)统一替换为标准空值
|
||||
代码:
|
||||
\`\`\`python
|
||||
try:
|
||||
df = df.replace(['-', '不详', 'NA', 'N/A', '\\\\', '未查'], np.nan)
|
||||
print(f'缺失值标记统一完成,当前缺失值数量: {df.isna().sum().sum()}')
|
||||
except Exception as e:
|
||||
print(f'处理错误: {e}')
|
||||
\`\`\`
|
||||
说明: 将多种缺失值表示统一为NaN,便于后续统计分析
|
||||
|
||||
### 示例2: 数值列清洗
|
||||
用户: 把肌酐列里的非数字符号去掉,<0.1按0.05处理,转为数值类型
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['creatinine'] = df['creatinine'].astype(str).str.replace('>', '').str.replace('<', '')
|
||||
df.loc[df['creatinine'] == '0.1', 'creatinine'] = '0.05'
|
||||
df['creatinine'] = pd.to_numeric(df['creatinine'], errors='coerce')
|
||||
\`\`\`
|
||||
说明: 检验科数据常含符号,需清理后才能计算
|
||||
|
||||
### 示例3: 分类变量编码
|
||||
用户: 把性别列转为数字,男=1,女=0
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['gender_code'] = df['gender'].map({'男': 1, '女': 0})
|
||||
\`\`\`
|
||||
说明: 将文本分类变量转为数值,便于统计建模
|
||||
|
||||
### 示例4: 连续变量分箱
|
||||
用户: 把年龄按18岁、60岁分为未成年、成年、老年三组
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['age_group'] = pd.cut(df['age'], bins=[0, 18, 60, 120], labels=['未成年', '成年', '老年'], right=False)
|
||||
\`\`\`
|
||||
说明: 将连续变量离散化,用于分层分析或卡方检验
|
||||
|
||||
### 示例5: BMI计算与分类
|
||||
用户: 根据身高(cm)和体重(kg)计算BMI,并标记BMI≥28为肥胖
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['BMI'] = df['weight'] / (df['height'] / 100) ** 2
|
||||
df['obesity'] = df['BMI'].apply(lambda x: '肥胖' if x >= 28 else '正常')
|
||||
\`\`\`
|
||||
说明: 临床常用的体质指标计算和分类
|
||||
|
||||
### 示例6: 日期计算
|
||||
用户: 根据入院日期和出院日期计算住院天数
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['admission_date'] = pd.to_datetime(df['admission_date'])
|
||||
df['discharge_date'] = pd.to_datetime(df['discharge_date'])
|
||||
df['length_of_stay'] = (df['discharge_date'] - df['admission_date']).dt.days
|
||||
\`\`\`
|
||||
说明: 医疗数据常需计算时间间隔(住院天数、随访时间等)
|
||||
|
||||
### 示例7: 条件筛选(入组标准)
|
||||
用户: 筛选出年龄≥18岁、诊断为糖尿病、且血糖≥7.0的患者
|
||||
代码:
|
||||
\`\`\`python
|
||||
df_selected = df[(df['age'] >= 18) & (df['diagnosis'] == '糖尿病') & (df['glucose'] >= 7.0)]
|
||||
\`\`\`
|
||||
说明: 临床研究常需根据入组/排除标准筛选病例
|
||||
|
||||
### 示例8: 简单缺失值填补
|
||||
用户: 用中位数填补BMI列的缺失值
|
||||
代码:
|
||||
\`\`\`python
|
||||
bmi_median = df['BMI'].median()
|
||||
df['BMI'] = df['BMI'].fillna(bmi_median)
|
||||
\`\`\`
|
||||
说明: 简单填补适用于缺失率<5%且MCAR(完全随机缺失)的情况
|
||||
|
||||
### 示例9: 智能多列缺失值填补
|
||||
用户: 对BMI、年龄、肌酐列的缺失值进行智能填补
|
||||
代码:
|
||||
\`\`\`python
|
||||
try:
|
||||
# 检查列是否存在
|
||||
cols = ['BMI', 'age', 'creatinine']
|
||||
missing_cols = [c for c in cols if c not in df.columns]
|
||||
if missing_cols:
|
||||
print(f'警告:以下列不存在: {missing_cols}')
|
||||
else:
|
||||
# 转换为数值类型
|
||||
for col in cols:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
|
||||
# 根据列特性选择填补策略
|
||||
df['age'] = df['age'].fillna(df['age'].median()) # 年龄用中位数
|
||||
df['BMI'] = df['BMI'].fillna(df.groupby('gender')['BMI'].transform('median')) # BMI按性别分组填补
|
||||
df['creatinine'] = df['creatinine'].fillna(df['creatinine'].mean()) # 肌酐用均值
|
||||
|
||||
print('缺失值填补完成')
|
||||
print(f'年龄缺失: {df["age"].isna().sum()}')
|
||||
print(f'BMI缺失: {df["BMI"].isna().sum()}')
|
||||
print(f'肌酐缺失: {df["creatinine"].isna().sum()}')
|
||||
except Exception as e:
|
||||
print(f'填补错误: {e}')
|
||||
\`\`\`
|
||||
说明: 根据医学变量特性选择不同填补策略:年龄用中位数(稳健),BMI按性别分组(考虑性别差异),肌酐用均值
|
||||
|
||||
### 示例10: 智能去重
|
||||
用户: 按患者ID去重,保留检查日期最新的记录
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['check_date'] = pd.to_datetime(df['check_date'])
|
||||
df = df.sort_values('check_date').drop_duplicates(subset=['patient_id'], keep='last')
|
||||
\`\`\`
|
||||
说明: 先按日期排序,再去重保留最后一条(最新)
|
||||
|
||||
## 用户当前请求
|
||||
请根据以上示例和当前数据集信息,生成代码并解释。返回JSON格式:{"code": "...", "explanation": "..."}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析AI回复(提取code和explanation)
|
||||
* @private
|
||||
*/
|
||||
private parseAIResponse(content: string): { code: string; explanation: string } {
|
||||
try {
|
||||
// 方法1:尝试解析JSON
|
||||
const json = JSON.parse(content);
|
||||
if (json.code && json.explanation) {
|
||||
return { code: json.code, explanation: json.explanation };
|
||||
}
|
||||
} catch {
|
||||
// 方法2:正则提取代码块
|
||||
const codeMatch = content.match(/```python\n([\s\S]+?)\n```/);
|
||||
const code = codeMatch ? codeMatch[1].trim() : '';
|
||||
|
||||
// 提取解释(代码块之外的文本)
|
||||
let explanation = content.replace(/```python[\s\S]+?```/g, '').trim();
|
||||
|
||||
// 如果没有单独的解释,尝试提取JSON中的explanation
|
||||
try {
|
||||
const jsonMatch = content.match(/\{[\s\S]*"explanation":\s*"([^"]+)"[\s\S]*\}/);
|
||||
if (jsonMatch) {
|
||||
explanation = jsonMatch[1];
|
||||
}
|
||||
} catch {
|
||||
// 忽略
|
||||
}
|
||||
|
||||
if (code) {
|
||||
return { code, explanation: explanation || '代码已生成' };
|
||||
}
|
||||
}
|
||||
|
||||
logger.error(`[AICodeService] AI回复格式错误: ${content.substring(0, 200)}`);
|
||||
throw new Error('AI回复格式错误,无法提取代码。请重试。');
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存消息到数据库
|
||||
* @private
|
||||
*/
|
||||
private async saveMessages(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
userMessage: string,
|
||||
code: string,
|
||||
explanation: string
|
||||
): Promise<string> {
|
||||
try {
|
||||
// 保存用户消息
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.create({
|
||||
data: {
|
||||
sessionId,
|
||||
userId,
|
||||
role: 'user',
|
||||
content: userMessage
|
||||
}
|
||||
});
|
||||
|
||||
// 保存AI回复
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
const assistantMessage = await prisma.dcToolCAiHistory.create({
|
||||
data: {
|
||||
sessionId,
|
||||
userId,
|
||||
role: 'assistant',
|
||||
content: explanation,
|
||||
generatedCode: code,
|
||||
codeExplanation: explanation,
|
||||
executeStatus: 'pending',
|
||||
model: 'deepseek-v3'
|
||||
}
|
||||
});
|
||||
|
||||
return assistantMessage.id;
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 保存消息失败: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 导出单例实例 ====================
|
||||
|
||||
export const aiCodeService = new AICodeService();
|
||||
|
||||
@@ -26,3 +26,5 @@ Write-Host "✅ 完成!" -ForegroundColor Green
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -380,3 +380,4 @@ runAllTests()
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
|
||||
|
||||
341
backend/test-tool-c-day3.mjs
Normal file
341
backend/test-tool-c-day3.mjs
Normal file
@@ -0,0 +1,341 @@
|
||||
/**
|
||||
* Tool C Day 3 测试脚本
|
||||
*
|
||||
* 测试内容:
|
||||
* 1. 10个Few-shot示例场景测试
|
||||
* 2. AI自我修正机制测试(重试)
|
||||
* 3. 端到端测试
|
||||
*
|
||||
* 前提:
|
||||
* - 需要先创建一个Session(使用Day 2的upload接口)
|
||||
* - 需要Python服务运行(端口8000)
|
||||
* - 需要后端服务运行(端口3000)
|
||||
*
|
||||
* 执行方式:node test-tool-c-day3.mjs
|
||||
*/
|
||||
|
||||
import axios from 'axios';
|
||||
import FormData from 'form-data';
|
||||
import * as XLSX from 'xlsx';
|
||||
|
||||
const BASE_URL = 'http://localhost:3000';
|
||||
const API_PREFIX = '/api/v1/dc/tool-c';
|
||||
|
||||
let testSessionId = null;
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
function printSection(title) {
|
||||
console.log('\n' + '='.repeat(70));
|
||||
console.log(` ${title}`);
|
||||
console.log('='.repeat(70) + '\n');
|
||||
}
|
||||
|
||||
function printSuccess(message) {
|
||||
console.log('✅ ' + message);
|
||||
}
|
||||
|
||||
function printError(message) {
|
||||
console.log('❌ ' + message);
|
||||
}
|
||||
|
||||
function printInfo(message) {
|
||||
console.log('ℹ️ ' + message);
|
||||
}
|
||||
|
||||
// ==================== 准备测试Session ====================
|
||||
|
||||
async function createTestSession() {
|
||||
printSection('准备:创建测试Session');
|
||||
|
||||
try {
|
||||
// 创建测试Excel数据
|
||||
const testData = [
|
||||
{ patient_id: 'P001', name: '张三', age: 25, gender: '男', diagnosis: '感冒', sbp: 120, dbp: 80, weight: 70, height: 175, BMI: '', creatinine: '>100', check_date: '2024-01-01' },
|
||||
{ patient_id: 'P002', name: '李四', age: 65, gender: '女', diagnosis: '高血压', sbp: 150, dbp: 95, weight: 65, height: 160, BMI: '', creatinine: '<0.1', check_date: '2024-01-05' },
|
||||
{ patient_id: 'P003', name: '王五', age: 45, gender: '男', diagnosis: '糖尿病', sbp: 135, dbp: 85, weight: 80, height: 170, BMI: '', creatinine: '85', check_date: '2024-01-03' },
|
||||
{ patient_id: 'P004', name: '赵六', age: 70, gender: '女', diagnosis: '冠心病', sbp: 160, dbp: 100, weight: 60, height: 155, BMI: '', creatinine: '120', check_date: '2024-01-10' },
|
||||
{ patient_id: 'P005', name: '钱七', age: 35, gender: '男', diagnosis: '胃炎', sbp: 110, dbp: 70, weight: 75, height: 180, BMI: '', creatinine: '-', check_date: '2024-01-08' },
|
||||
{ patient_id: 'P003', name: '王五', age: 45, gender: '男', diagnosis: '糖尿病', sbp: 138, dbp: 88, weight: 82, height: 170, BMI: '', creatinine: '88', check_date: '2024-01-12' }, // 重复ID,日期更新
|
||||
];
|
||||
|
||||
// 生成Excel
|
||||
const ws = XLSX.utils.json_to_sheet(testData);
|
||||
const wb = XLSX.utils.book_new();
|
||||
XLSX.utils.book_append_sheet(wb, ws, 'Sheet1');
|
||||
const excelBuffer = XLSX.write(wb, { type: 'buffer', bookType: 'xlsx' });
|
||||
|
||||
// 上传创建Session
|
||||
const form = new FormData();
|
||||
form.append('file', excelBuffer, {
|
||||
filename: 'test-medical-data-day3.xlsx',
|
||||
contentType: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
});
|
||||
|
||||
const response = await axios.post(
|
||||
`${BASE_URL}${API_PREFIX}/sessions/upload`,
|
||||
form,
|
||||
{ headers: form.getHeaders(), timeout: 10000 }
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
testSessionId = response.data.data.sessionId;
|
||||
printSuccess(`Session创建成功: ${testSessionId}`);
|
||||
printInfo(`数据: ${testData.length}行 x ${Object.keys(testData[0]).length}列`);
|
||||
return true;
|
||||
} else {
|
||||
printError('Session创建失败');
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
printError('Session创建异常: ' + error.message);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== AI测试函数 ====================
|
||||
|
||||
async function testAIGenerate(testName, userMessage, shouldSucceed = true) {
|
||||
printSection(`测试: ${testName}`);
|
||||
|
||||
try {
|
||||
printInfo(`用户需求: ${userMessage}`);
|
||||
|
||||
const response = await axios.post(
|
||||
`${BASE_URL}${API_PREFIX}/ai/generate`,
|
||||
{
|
||||
sessionId: testSessionId,
|
||||
message: userMessage
|
||||
},
|
||||
{ timeout: 30000 } // AI调用可能需要较长时间
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
printSuccess('AI生成代码成功');
|
||||
console.log('\n生成的代码:');
|
||||
console.log('```python');
|
||||
console.log(response.data.data.code);
|
||||
console.log('```\n');
|
||||
console.log('解释:', response.data.data.explanation);
|
||||
console.log('MessageID:', response.data.data.messageId);
|
||||
return { success: true, data: response.data.data };
|
||||
} else {
|
||||
if (shouldSucceed) {
|
||||
printError('AI生成失败: ' + response.data.error);
|
||||
} else {
|
||||
printInfo('预期失败: ' + response.data.error);
|
||||
}
|
||||
return { success: false, error: response.data.error };
|
||||
}
|
||||
} catch (error) {
|
||||
printError('AI生成异常: ' + (error.response?.data?.error || error.message));
|
||||
return { success: false, error: error.message };
|
||||
}
|
||||
}
|
||||
|
||||
async function testAIProcess(testName, userMessage) {
|
||||
printSection(`测试(一步到位): ${testName}`);
|
||||
|
||||
try {
|
||||
printInfo(`用户需求: ${userMessage}`);
|
||||
|
||||
const response = await axios.post(
|
||||
`${BASE_URL}${API_PREFIX}/ai/process`,
|
||||
{
|
||||
sessionId: testSessionId,
|
||||
message: userMessage,
|
||||
maxRetries: 3
|
||||
},
|
||||
{ timeout: 60000 } // 带重试可能需要更长时间
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
printSuccess(`执行成功${response.data.data.retryCount > 0 ? `(重试${response.data.data.retryCount}次)` : ''}`);
|
||||
console.log('\n生成的代码:');
|
||||
console.log('```python');
|
||||
console.log(response.data.data.code);
|
||||
console.log('```\n');
|
||||
console.log('解释:', response.data.data.explanation);
|
||||
|
||||
if (response.data.data.executeResult.success) {
|
||||
printSuccess('代码执行成功');
|
||||
console.log('数据预览(前5行):');
|
||||
console.log(JSON.stringify(response.data.data.executeResult.newDataPreview?.slice(0, 5), null, 2));
|
||||
}
|
||||
|
||||
return { success: true, data: response.data.data };
|
||||
} else {
|
||||
printError('处理失败: ' + response.data.error);
|
||||
return { success: false, error: response.data.error };
|
||||
}
|
||||
} catch (error) {
|
||||
printError('处理异常: ' + (error.response?.data?.error || error.message));
|
||||
return { success: false, error: error.message };
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 主测试函数 ====================
|
||||
|
||||
async function runAllTests() {
|
||||
console.log('\n' + '🚀'.repeat(35));
|
||||
console.log(' Tool C Day 3 测试 - AI代码生成');
|
||||
console.log('🚀'.repeat(35));
|
||||
|
||||
const results = {};
|
||||
|
||||
try {
|
||||
// 0. 准备测试Session
|
||||
const sessionCreated = await createTestSession();
|
||||
if (!sessionCreated) {
|
||||
printError('测试Session创建失败,无法继续');
|
||||
return;
|
||||
}
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 2000));
|
||||
|
||||
// ==================== 10个Few-shot示例测试 ====================
|
||||
|
||||
// 测试1: 统一缺失值标记
|
||||
let result = await testAIProcess(
|
||||
'示例1: 统一缺失值标记',
|
||||
'把所有代表缺失的符号(-、不详、NA、N/A)统一替换为标准空值'
|
||||
);
|
||||
results['示例1-缺失值'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试2: 数值列清洗
|
||||
result = await testAIProcess(
|
||||
'示例2: 数值列清洗',
|
||||
'把creatinine列里的非数字符号去掉,<0.1按0.05处理,转为数值类型'
|
||||
);
|
||||
results['示例2-数值清洗'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试3: 分类变量编码
|
||||
result = await testAIProcess(
|
||||
'示例3: 分类变量编码',
|
||||
'把gender列转为数字,男=1,女=0'
|
||||
);
|
||||
results['示例3-编码'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试4: 连续变量分箱
|
||||
result = await testAIProcess(
|
||||
'示例4: 连续变量分箱',
|
||||
'把age列按18岁、60岁分为未成年、成年、老年三组'
|
||||
);
|
||||
results['示例4-分箱'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试5: BMI计算
|
||||
result = await testAIProcess(
|
||||
'示例5: BMI计算',
|
||||
'根据weight和height计算BMI,并标记BMI≥28为肥胖'
|
||||
);
|
||||
results['示例5-BMI'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试6: 条件筛选
|
||||
result = await testAIProcess(
|
||||
'示例6: 条件筛选',
|
||||
'筛选出年龄≥60岁、且sbp≥140的患者'
|
||||
);
|
||||
results['示例6-筛选'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试7: 智能去重
|
||||
result = await testAIProcess(
|
||||
'示例7: 智能去重',
|
||||
'按patient_id去重,保留check_date最新的记录'
|
||||
);
|
||||
results['示例7-去重'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试8: 中位数填补(简化版,跳过多重插补)
|
||||
result = await testAIProcess(
|
||||
'示例8: 缺失值填补',
|
||||
'用age列的中位数填补age列的缺失值'
|
||||
);
|
||||
results['示例8-填补'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试9: 统计汇总
|
||||
result = await testAIProcess(
|
||||
'示例9: 统计汇总',
|
||||
'按diagnosis分组,统计每个诊断的平均年龄和患者数量'
|
||||
);
|
||||
results['示例9-统计'] = result.success;
|
||||
await new Promise(resolve => setTimeout(resolve, 3000));
|
||||
|
||||
// 测试10: 复杂计算
|
||||
result = await testAIProcess(
|
||||
'示例10: 复杂计算',
|
||||
'根据sbp判断血压分类:正常(<140)、高血压I级(140-159)、高血压II级(≥160)'
|
||||
);
|
||||
results['示例10-分类'] = result.success;
|
||||
|
||||
// ==================== 对话历史测试 ====================
|
||||
|
||||
printSection('测试: 获取对话历史');
|
||||
try {
|
||||
const historyResponse = await axios.get(
|
||||
`${BASE_URL}${API_PREFIX}/ai/history/${testSessionId}?limit=5`
|
||||
);
|
||||
|
||||
if (historyResponse.data.success) {
|
||||
printSuccess(`获取历史成功: ${historyResponse.data.data.count}条`);
|
||||
results['对话历史'] = true;
|
||||
} else {
|
||||
printError('获取历史失败');
|
||||
results['对话历史'] = false;
|
||||
}
|
||||
} catch (error) {
|
||||
printError('获取历史异常: ' + error.message);
|
||||
results['对话历史'] = false;
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
printError('测试过程中发生异常: ' + error.message);
|
||||
console.error(error);
|
||||
}
|
||||
|
||||
// 汇总结果
|
||||
printSection('测试结果汇总');
|
||||
|
||||
let passed = 0;
|
||||
let total = 0;
|
||||
|
||||
for (const [testName, result] of Object.entries(results)) {
|
||||
total++;
|
||||
if (result) {
|
||||
passed++;
|
||||
console.log(`${testName.padEnd(20)}: ✅ 通过`);
|
||||
} else {
|
||||
console.log(`${testName.padEnd(20)}: ❌ 失败`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('\n' + '-'.repeat(70));
|
||||
console.log(`总计: ${passed}/${total} 通过 (${((passed/total)*100).toFixed(1)}%)`);
|
||||
console.log('-'.repeat(70));
|
||||
|
||||
if (passed === total) {
|
||||
console.log('\n🎉 所有测试通过!Day 3 AI功能完成!\n');
|
||||
} else if (passed >= total * 0.7) {
|
||||
console.log(`\n⚠️ 有 ${total - passed} 个测试失败,但通过率≥70%,基本可用\n`);
|
||||
} else {
|
||||
console.log(`\n❌ 通过率过低,需要调试\n`);
|
||||
}
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
runAllTests()
|
||||
.then(() => {
|
||||
console.log('测试完成');
|
||||
process.exit(0);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('测试失败:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user