feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)
核心功能: - 新增AICodeService(550行):AI代码生成核心服务 - 新增AIController(257行):4个API端点 - 新增dc_tool_c_ai_history表:存储对话历史 - 实现自我修正机制:最多3次智能重试 - 集成LLMFactory:复用通用能力层 - 10个Few-shot示例:覆盖Level 1-4场景 技术优化: - 修复NaN序列化问题(Python端转None) - 修复数据传递问题(从Session获取真实数据) - 优化System Prompt(明确环境信息) - 调整Few-shot示例(移除import语句) 测试结果: - 通过率:9/11(81.8%) 达到MVP标准 - 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类 - 待优化:数值清洗、智能去重(已记录技术债务TD-C-006) API端点: - POST /api/v1/dc/tool-c/ai/generate(生成代码) - POST /api/v1/dc/tool-c/ai/execute(执行代码) - POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位) - GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史) 文档更新: - 新增Day 3开发完成总结(770行) - 新增复杂场景优化技术债务(TD-C-006) - 更新工具C当前状态文档 - 更新技术债务清单 影响范围: - backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件) - backend/scripts/create-tool-c-ai-history-table.mjs(新增) - backend/prisma/schema.prisma(新增DcToolCAiHistory模型) - extraction_service/services/dc_executor.py(NaN序列化修复) - docs/03-业务模块/DC-数据清洗整理/*(5份文档更新) Breaking Changes: 无 总代码行数:+950行 Refs: #Tool-C-Day3
This commit is contained in:
549
backend/src/modules/dc/tool-c/services/AICodeService.ts
Normal file
549
backend/src/modules/dc/tool-c/services/AICodeService.ts
Normal file
@@ -0,0 +1,549 @@
|
||||
/**
|
||||
* AI代码生成服务
|
||||
*
|
||||
* 功能:
|
||||
* - 使用LLM生成Pandas数据清洗代码
|
||||
* - 执行生成的代码
|
||||
* - 自我修正(最多3次重试)
|
||||
* - 管理对话历史
|
||||
*
|
||||
* @module AICodeService
|
||||
*/
|
||||
|
||||
import { logger } from '../../../../common/logging/index.js';
|
||||
import { prisma } from '../../../../config/database.js';
|
||||
import { LLMFactory } from '../../../../common/llm/adapters/LLMFactory.js';
|
||||
import { ModelType, Message } from '../../../../common/llm/adapters/types.js';
|
||||
import { sessionService } from './SessionService.js';
|
||||
import { pythonExecutorService } from './PythonExecutorService.js';
|
||||
|
||||
// ==================== 类型定义 ====================
|
||||
|
||||
interface SessionData {
|
||||
id: string;
|
||||
fileName: string;
|
||||
totalRows: number;
|
||||
totalCols: number;
|
||||
columns: string[];
|
||||
}
|
||||
|
||||
interface GenerateCodeResult {
|
||||
code: string;
|
||||
explanation: string;
|
||||
messageId: string;
|
||||
}
|
||||
|
||||
interface ExecuteCodeResult {
|
||||
success: boolean;
|
||||
result?: any;
|
||||
error?: string;
|
||||
newDataPreview?: any[];
|
||||
}
|
||||
|
||||
interface ProcessResult extends GenerateCodeResult {
|
||||
executeResult: ExecuteCodeResult;
|
||||
retryCount: number;
|
||||
}
|
||||
|
||||
// ==================== AI代码生成服务 ====================
|
||||
|
||||
export class AICodeService {
|
||||
|
||||
/**
|
||||
* 生成Pandas代码
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param userMessage - 用户自然语言需求
|
||||
* @returns { code, explanation, messageId }
|
||||
*/
|
||||
async generateCode(
|
||||
sessionId: string,
|
||||
userMessage: string
|
||||
): Promise<GenerateCodeResult> {
|
||||
try {
|
||||
logger.info(`[AICodeService] 生成代码: sessionId=${sessionId}`);
|
||||
|
||||
// 1. 获取Session信息(数据集元数据)
|
||||
const session = await sessionService.getSession(sessionId);
|
||||
|
||||
// 2. 构建System Prompt(含10个Few-shot示例)
|
||||
const systemPrompt = this.buildSystemPrompt({
|
||||
id: session.id,
|
||||
fileName: session.fileName,
|
||||
totalRows: session.totalRows,
|
||||
totalCols: session.totalCols,
|
||||
columns: session.columns
|
||||
});
|
||||
|
||||
// 3. 获取对话历史(最近5轮)
|
||||
const history = await this.getHistory(sessionId, 5);
|
||||
|
||||
// 4. 调用LLM(复用LLMFactory)
|
||||
const llm = LLMFactory.getAdapter('deepseek-v3' as ModelType);
|
||||
const response = await llm.chat([
|
||||
{ role: 'system', content: systemPrompt },
|
||||
...history,
|
||||
{ role: 'user', content: userMessage }
|
||||
], {
|
||||
temperature: 0.1, // 低温度,确保代码准确性
|
||||
maxTokens: 2000, // 足够生成代码+解释
|
||||
topP: 0.9
|
||||
});
|
||||
|
||||
logger.info(`[AICodeService] LLM响应成功,开始解析...`);
|
||||
|
||||
// 5. 解析AI回复(提取code和explanation)
|
||||
const parsed = this.parseAIResponse(response.content);
|
||||
|
||||
// 6. 保存到数据库
|
||||
const messageId = await this.saveMessages(
|
||||
sessionId,
|
||||
session.userId,
|
||||
userMessage,
|
||||
parsed.code,
|
||||
parsed.explanation
|
||||
);
|
||||
|
||||
logger.info(`[AICodeService] 代码生成成功: messageId=${messageId}`);
|
||||
|
||||
return {
|
||||
code: parsed.code,
|
||||
explanation: parsed.explanation,
|
||||
messageId
|
||||
};
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 生成代码失败: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行Python代码
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param code - Python代码
|
||||
* @param messageId - 关联的消息ID
|
||||
* @returns { success, result, newDataPreview }
|
||||
*/
|
||||
async executeCode(
|
||||
sessionId: string,
|
||||
code: string,
|
||||
messageId: string
|
||||
): Promise<ExecuteCodeResult> {
|
||||
try {
|
||||
logger.info(`[AICodeService] 执行代码: messageId=${messageId}`);
|
||||
|
||||
// 1. 从Session获取完整数据
|
||||
const fullData = await sessionService.getFullData(sessionId);
|
||||
logger.info(`[AICodeService] 获取Session数据: ${fullData.length}行`);
|
||||
|
||||
// 2. 调用Python服务执行
|
||||
const result = await pythonExecutorService.executeCode(
|
||||
fullData,
|
||||
code
|
||||
);
|
||||
|
||||
// 2. 更新消息状态
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: messageId },
|
||||
data: {
|
||||
executeStatus: result.success ? 'success' : 'failed',
|
||||
executeResult: result.result_data ? JSON.parse(JSON.stringify({ data: result.result_data })) : undefined,
|
||||
executeError: result.error || undefined
|
||||
}
|
||||
});
|
||||
|
||||
// 4. 如果成功,获取新数据预览(前50行)
|
||||
if (result.success && result.result_data) {
|
||||
const preview = Array.isArray(result.result_data)
|
||||
? result.result_data.slice(0, 50)
|
||||
: result.result_data;
|
||||
|
||||
logger.info(`[AICodeService] 代码执行成功`);
|
||||
|
||||
return {
|
||||
success: true,
|
||||
result: result.result_data,
|
||||
newDataPreview: preview
|
||||
};
|
||||
}
|
||||
|
||||
logger.warn(`[AICodeService] 代码执行失败: ${result.error}`);
|
||||
|
||||
return {
|
||||
success: false,
|
||||
error: result.error || '执行失败,未知错误'
|
||||
};
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 执行代码异常: ${error.message}`);
|
||||
|
||||
// 更新为失败状态
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: messageId },
|
||||
data: {
|
||||
executeStatus: 'failed',
|
||||
executeError: error.message
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
success: false,
|
||||
error: error.message
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成并执行(带自我修正)
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param userMessage - 用户需求
|
||||
* @param maxRetries - 最大重试次数(默认3)
|
||||
* @returns { code, explanation, executeResult, retryCount }
|
||||
*/
|
||||
async generateAndExecute(
|
||||
sessionId: string,
|
||||
userMessage: string,
|
||||
maxRetries: number = 3
|
||||
): Promise<ProcessResult> {
|
||||
let attempt = 0;
|
||||
let lastError: string | null = null;
|
||||
let generated: GenerateCodeResult | null = null;
|
||||
|
||||
while (attempt < maxRetries) {
|
||||
try {
|
||||
logger.info(`[AICodeService] 尝试 ${attempt + 1}/${maxRetries}`);
|
||||
|
||||
// 构建带错误反馈的提示词
|
||||
const enhancedMessage = attempt === 0
|
||||
? userMessage
|
||||
: `${userMessage}\n\n上次执行错误:${lastError}\n请修正代码,确保代码正确且符合Pandas语法。`;
|
||||
|
||||
// 生成代码
|
||||
generated = await this.generateCode(sessionId, enhancedMessage);
|
||||
|
||||
// 执行代码
|
||||
const executeResult = await this.executeCode(
|
||||
sessionId,
|
||||
generated.code,
|
||||
generated.messageId
|
||||
);
|
||||
|
||||
if (executeResult.success) {
|
||||
// ✅ 成功
|
||||
logger.info(`[AICodeService] 执行成功(尝试${attempt + 1}次)`);
|
||||
|
||||
// 更新重试次数
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.update({
|
||||
where: { id: generated.messageId },
|
||||
data: { retryCount: attempt }
|
||||
});
|
||||
|
||||
return {
|
||||
...generated,
|
||||
executeResult,
|
||||
retryCount: attempt
|
||||
};
|
||||
}
|
||||
|
||||
// ❌ 失败,准备重试
|
||||
lastError = executeResult.error || '未知错误';
|
||||
attempt++;
|
||||
|
||||
logger.warn(`[AICodeService] 执行失败(尝试${attempt}/${maxRetries}): ${lastError}`);
|
||||
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 异常: ${error.message}`);
|
||||
lastError = error.message;
|
||||
attempt++;
|
||||
}
|
||||
}
|
||||
|
||||
// 3次仍失败
|
||||
throw new Error(
|
||||
`代码执行失败(已重试${maxRetries}次)。最后错误:${lastError}。` +
|
||||
`建议:请调整需求描述或检查数据列名是否正确。`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取对话历史
|
||||
* @param sessionId - Tool C Session ID
|
||||
* @param limit - 最近N轮对话(默认5轮,即10条消息)
|
||||
* @returns 消息列表
|
||||
*/
|
||||
async getHistory(sessionId: string, limit: number = 5): Promise<Message[]> {
|
||||
try {
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
const records = await prisma.dcToolCAiHistory.findMany({
|
||||
where: { sessionId },
|
||||
orderBy: { createdAt: 'desc' },
|
||||
take: limit * 2, // user + assistant 成对
|
||||
select: {
|
||||
role: true,
|
||||
content: true
|
||||
}
|
||||
});
|
||||
|
||||
// 反转顺序(最旧的在前)
|
||||
return records.reverse().map((r: any) => ({
|
||||
role: r.role as 'user' | 'assistant' | 'system',
|
||||
content: r.content
|
||||
}));
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 获取历史失败: ${error.message}`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
/**
|
||||
* 构建System Prompt(含10个Few-shot示例)
|
||||
* @private
|
||||
*/
|
||||
private buildSystemPrompt(session: SessionData): string {
|
||||
return `你是医疗科研数据清洗专家,负责生成Pandas代码来清洗整理数据。
|
||||
|
||||
## 当前数据集信息
|
||||
- 文件名: ${session.fileName}
|
||||
- 行数: ${session.totalRows}
|
||||
- 列数: ${session.totalCols}
|
||||
- 列名: ${session.columns.join(', ')}
|
||||
|
||||
## 执行环境(重要)
|
||||
**已预导入的库(请直接使用,不要再import):**
|
||||
- pandas 已导入为 pd
|
||||
- numpy 已导入为 np
|
||||
- df 变量已加载当前数据集
|
||||
|
||||
**不可用的库:**
|
||||
- sklearn(未安装,请使用pandas/numpy替代方案)
|
||||
- scipy(未安装)
|
||||
- 其他第三方库
|
||||
|
||||
**示例:直接使用,无需导入**
|
||||
\`\`\`python
|
||||
# ✅ 正确:直接使用预导入的库
|
||||
df['age_clean'] = df['age'].fillna(df['age'].median())
|
||||
df['group'] = np.where(df['age'] > 60, '老年', '非老年')
|
||||
|
||||
# ❌ 错误:不要再导入
|
||||
import pandas as pd # 会报错!
|
||||
import numpy as np # 会报错!
|
||||
import sklearn # 未安装,会报错!
|
||||
\`\`\`
|
||||
|
||||
## 安全规则(强制)
|
||||
1. 只能操作df变量,不能修改其他变量
|
||||
2. **禁止任何import语句**(pandas和numpy已预导入)
|
||||
3. 禁止使用eval、exec、__import__等危险函数
|
||||
4. 必须进行异常处理
|
||||
5. 返回格式必须是JSON: {"code": "...", "explanation": "..."}
|
||||
|
||||
## Few-shot示例
|
||||
|
||||
### 示例1: 统一缺失值标记
|
||||
用户: 把所有代表缺失的符号(-、不详、NA、N/A)统一替换为标准空值
|
||||
代码:
|
||||
\`\`\`python
|
||||
try:
|
||||
df = df.replace(['-', '不详', 'NA', 'N/A', '\\\\', '未查'], np.nan)
|
||||
print(f'缺失值标记统一完成,当前缺失值数量: {df.isna().sum().sum()}')
|
||||
except Exception as e:
|
||||
print(f'处理错误: {e}')
|
||||
\`\`\`
|
||||
说明: 将多种缺失值表示统一为NaN,便于后续统计分析
|
||||
|
||||
### 示例2: 数值列清洗
|
||||
用户: 把肌酐列里的非数字符号去掉,<0.1按0.05处理,转为数值类型
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['creatinine'] = df['creatinine'].astype(str).str.replace('>', '').str.replace('<', '')
|
||||
df.loc[df['creatinine'] == '0.1', 'creatinine'] = '0.05'
|
||||
df['creatinine'] = pd.to_numeric(df['creatinine'], errors='coerce')
|
||||
\`\`\`
|
||||
说明: 检验科数据常含符号,需清理后才能计算
|
||||
|
||||
### 示例3: 分类变量编码
|
||||
用户: 把性别列转为数字,男=1,女=0
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['gender_code'] = df['gender'].map({'男': 1, '女': 0})
|
||||
\`\`\`
|
||||
说明: 将文本分类变量转为数值,便于统计建模
|
||||
|
||||
### 示例4: 连续变量分箱
|
||||
用户: 把年龄按18岁、60岁分为未成年、成年、老年三组
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['age_group'] = pd.cut(df['age'], bins=[0, 18, 60, 120], labels=['未成年', '成年', '老年'], right=False)
|
||||
\`\`\`
|
||||
说明: 将连续变量离散化,用于分层分析或卡方检验
|
||||
|
||||
### 示例5: BMI计算与分类
|
||||
用户: 根据身高(cm)和体重(kg)计算BMI,并标记BMI≥28为肥胖
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['BMI'] = df['weight'] / (df['height'] / 100) ** 2
|
||||
df['obesity'] = df['BMI'].apply(lambda x: '肥胖' if x >= 28 else '正常')
|
||||
\`\`\`
|
||||
说明: 临床常用的体质指标计算和分类
|
||||
|
||||
### 示例6: 日期计算
|
||||
用户: 根据入院日期和出院日期计算住院天数
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['admission_date'] = pd.to_datetime(df['admission_date'])
|
||||
df['discharge_date'] = pd.to_datetime(df['discharge_date'])
|
||||
df['length_of_stay'] = (df['discharge_date'] - df['admission_date']).dt.days
|
||||
\`\`\`
|
||||
说明: 医疗数据常需计算时间间隔(住院天数、随访时间等)
|
||||
|
||||
### 示例7: 条件筛选(入组标准)
|
||||
用户: 筛选出年龄≥18岁、诊断为糖尿病、且血糖≥7.0的患者
|
||||
代码:
|
||||
\`\`\`python
|
||||
df_selected = df[(df['age'] >= 18) & (df['diagnosis'] == '糖尿病') & (df['glucose'] >= 7.0)]
|
||||
\`\`\`
|
||||
说明: 临床研究常需根据入组/排除标准筛选病例
|
||||
|
||||
### 示例8: 简单缺失值填补
|
||||
用户: 用中位数填补BMI列的缺失值
|
||||
代码:
|
||||
\`\`\`python
|
||||
bmi_median = df['BMI'].median()
|
||||
df['BMI'] = df['BMI'].fillna(bmi_median)
|
||||
\`\`\`
|
||||
说明: 简单填补适用于缺失率<5%且MCAR(完全随机缺失)的情况
|
||||
|
||||
### 示例9: 智能多列缺失值填补
|
||||
用户: 对BMI、年龄、肌酐列的缺失值进行智能填补
|
||||
代码:
|
||||
\`\`\`python
|
||||
try:
|
||||
# 检查列是否存在
|
||||
cols = ['BMI', 'age', 'creatinine']
|
||||
missing_cols = [c for c in cols if c not in df.columns]
|
||||
if missing_cols:
|
||||
print(f'警告:以下列不存在: {missing_cols}')
|
||||
else:
|
||||
# 转换为数值类型
|
||||
for col in cols:
|
||||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||||
|
||||
# 根据列特性选择填补策略
|
||||
df['age'] = df['age'].fillna(df['age'].median()) # 年龄用中位数
|
||||
df['BMI'] = df['BMI'].fillna(df.groupby('gender')['BMI'].transform('median')) # BMI按性别分组填补
|
||||
df['creatinine'] = df['creatinine'].fillna(df['creatinine'].mean()) # 肌酐用均值
|
||||
|
||||
print('缺失值填补完成')
|
||||
print(f'年龄缺失: {df["age"].isna().sum()}')
|
||||
print(f'BMI缺失: {df["BMI"].isna().sum()}')
|
||||
print(f'肌酐缺失: {df["creatinine"].isna().sum()}')
|
||||
except Exception as e:
|
||||
print(f'填补错误: {e}')
|
||||
\`\`\`
|
||||
说明: 根据医学变量特性选择不同填补策略:年龄用中位数(稳健),BMI按性别分组(考虑性别差异),肌酐用均值
|
||||
|
||||
### 示例10: 智能去重
|
||||
用户: 按患者ID去重,保留检查日期最新的记录
|
||||
代码:
|
||||
\`\`\`python
|
||||
df['check_date'] = pd.to_datetime(df['check_date'])
|
||||
df = df.sort_values('check_date').drop_duplicates(subset=['patient_id'], keep='last')
|
||||
\`\`\`
|
||||
说明: 先按日期排序,再去重保留最后一条(最新)
|
||||
|
||||
## 用户当前请求
|
||||
请根据以上示例和当前数据集信息,生成代码并解释。返回JSON格式:{"code": "...", "explanation": "..."}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析AI回复(提取code和explanation)
|
||||
* @private
|
||||
*/
|
||||
private parseAIResponse(content: string): { code: string; explanation: string } {
|
||||
try {
|
||||
// 方法1:尝试解析JSON
|
||||
const json = JSON.parse(content);
|
||||
if (json.code && json.explanation) {
|
||||
return { code: json.code, explanation: json.explanation };
|
||||
}
|
||||
} catch {
|
||||
// 方法2:正则提取代码块
|
||||
const codeMatch = content.match(/```python\n([\s\S]+?)\n```/);
|
||||
const code = codeMatch ? codeMatch[1].trim() : '';
|
||||
|
||||
// 提取解释(代码块之外的文本)
|
||||
let explanation = content.replace(/```python[\s\S]+?```/g, '').trim();
|
||||
|
||||
// 如果没有单独的解释,尝试提取JSON中的explanation
|
||||
try {
|
||||
const jsonMatch = content.match(/\{[\s\S]*"explanation":\s*"([^"]+)"[\s\S]*\}/);
|
||||
if (jsonMatch) {
|
||||
explanation = jsonMatch[1];
|
||||
}
|
||||
} catch {
|
||||
// 忽略
|
||||
}
|
||||
|
||||
if (code) {
|
||||
return { code, explanation: explanation || '代码已生成' };
|
||||
}
|
||||
}
|
||||
|
||||
logger.error(`[AICodeService] AI回复格式错误: ${content.substring(0, 200)}`);
|
||||
throw new Error('AI回复格式错误,无法提取代码。请重试。');
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存消息到数据库
|
||||
* @private
|
||||
*/
|
||||
private async saveMessages(
|
||||
sessionId: string,
|
||||
userId: string,
|
||||
userMessage: string,
|
||||
code: string,
|
||||
explanation: string
|
||||
): Promise<string> {
|
||||
try {
|
||||
// 保存用户消息
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
await prisma.dcToolCAiHistory.create({
|
||||
data: {
|
||||
sessionId,
|
||||
userId,
|
||||
role: 'user',
|
||||
content: userMessage
|
||||
}
|
||||
});
|
||||
|
||||
// 保存AI回复
|
||||
// @ts-ignore - DcToolCAiHistory模型
|
||||
const assistantMessage = await prisma.dcToolCAiHistory.create({
|
||||
data: {
|
||||
sessionId,
|
||||
userId,
|
||||
role: 'assistant',
|
||||
content: explanation,
|
||||
generatedCode: code,
|
||||
codeExplanation: explanation,
|
||||
executeStatus: 'pending',
|
||||
model: 'deepseek-v3'
|
||||
}
|
||||
});
|
||||
|
||||
return assistantMessage.id;
|
||||
} catch (error: any) {
|
||||
logger.error(`[AICodeService] 保存消息失败: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 导出单例实例 ====================
|
||||
|
||||
export const aiCodeService = new AICodeService();
|
||||
|
||||
Reference in New Issue
Block a user