Files
HaHafeng f01981bf78 feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)
核心功能:
- 新增AICodeService(550行):AI代码生成核心服务
- 新增AIController(257行):4个API端点
- 新增dc_tool_c_ai_history表:存储对话历史
- 实现自我修正机制:最多3次智能重试
- 集成LLMFactory:复用通用能力层
- 10个Few-shot示例:覆盖Level 1-4场景

技术优化:
- 修复NaN序列化问题(Python端转None)
- 修复数据传递问题(从Session获取真实数据)
- 优化System Prompt(明确环境信息)
- 调整Few-shot示例(移除import语句)

测试结果:
- 通过率:9/11(81.8%) 达到MVP标准
- 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类
- 待优化:数值清洗、智能去重(已记录技术债务TD-C-006)

API端点:
- POST /api/v1/dc/tool-c/ai/generate(生成代码)
- POST /api/v1/dc/tool-c/ai/execute(执行代码)
- POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位)
- GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史)

文档更新:
- 新增Day 3开发完成总结(770行)
- 新增复杂场景优化技术债务(TD-C-006)
- 更新工具C当前状态文档
- 更新技术债务清单

影响范围:
- backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件)
- backend/scripts/create-tool-c-ai-history-table.mjs(新增)
- backend/prisma/schema.prisma(新增DcToolCAiHistory模型)
- extraction_service/services/dc_executor.py(NaN序列化修复)
- docs/03-业务模块/DC-数据清洗整理/*(5份文档更新)

Breaking Changes: 无

总代码行数:+950行

Refs: #Tool-C-Day3
2025-12-07 16:21:32 +08:00

428 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
DC工具C - Pandas代码执行服务
功能:
- AST静态代码检查安全验证
- Pandas代码执行沙箱环境
- 危险模块拦截
- 超时保护
"""
import ast
import sys
import pandas as pd
import numpy as np
import io
import traceback
from typing import Dict, Any, List, Tuple
from loguru import logger
import signal
import json
# ==================== 配置常量 ====================
# 危险模块黑名单(禁止导入)
DANGEROUS_MODULES = {
'os', 'sys', 'subprocess', 'shutil', 'glob',
'socket', 'urllib', 'requests', 'http',
'pickle', 'shelve', 'dbm',
'importlib', '__import__',
'eval', 'exec', 'compile',
'open', 'input', 'file',
}
# 危险内置函数黑名单
DANGEROUS_BUILTINS = {
'eval', 'exec', 'compile', '__import__',
'open', 'input', 'file',
'getattr', 'setattr', 'delattr',
'globals', 'locals', 'vars',
}
# 允许的安全模块
SAFE_MODULES = {
'pandas', 'numpy', 'pd', 'np',
'datetime', 'math', 'statistics',
'json', 're', 'collections',
}
# 代码执行超时时间(秒)
EXECUTION_TIMEOUT = 30
# ==================== AST安全检查 ====================
class SecurityVisitor(ast.NodeVisitor):
"""
AST访问器 - 检查代码中的危险操作
"""
def __init__(self):
self.errors: List[str] = []
self.warnings: List[str] = []
def visit_Import(self, node: ast.Import):
"""检查import语句"""
for alias in node.names:
module = alias.name.split('.')[0]
if module in DANGEROUS_MODULES:
self.errors.append(
f"🚫 禁止导入危险模块: {module} (行 {node.lineno})"
)
elif module not in SAFE_MODULES:
self.warnings.append(
f"⚠️ 导入了未知模块: {module} (行 {node.lineno})"
)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
"""检查from...import语句"""
if node.module:
module = node.module.split('.')[0]
if module in DANGEROUS_MODULES:
self.errors.append(
f"🚫 禁止从危险模块导入: {module} (行 {node.lineno})"
)
elif module not in SAFE_MODULES:
self.warnings.append(
f"⚠️ 从未知模块导入: {module} (行 {node.lineno})"
)
self.generic_visit(node)
def visit_Call(self, node: ast.Call):
"""检查函数调用"""
# 检查是否调用了危险内置函数
if isinstance(node.func, ast.Name):
func_name = node.func.id
if func_name in DANGEROUS_BUILTINS:
self.errors.append(
f"🚫 禁止调用危险函数: {func_name}() (行 {node.lineno})"
)
# 检查是否调用了open文件操作
if isinstance(node.func, ast.Attribute):
if node.func.attr == 'open':
self.errors.append(
f"🚫 禁止文件操作: open() (行 {node.lineno})"
)
self.generic_visit(node)
def visit_Delete(self, node: ast.Delete):
"""检查删除操作"""
# 不允许删除df变量本身
for target in node.targets:
if isinstance(target, ast.Name) and target.id == 'df':
self.errors.append(
f"🚫 禁止删除DataFrame变量: del df (行 {node.lineno})"
)
self.generic_visit(node)
def validate_code(code: str) -> Dict[str, Any]:
"""
AST静态代码检查
Args:
code: 待检查的Python代码
Returns:
{
"valid": bool,
"errors": List[str],
"warnings": List[str]
}
"""
try:
logger.info(f"开始AST代码检查代码长度: {len(code)} 字符")
# 解析AST
try:
tree = ast.parse(code)
except SyntaxError as e:
return {
"valid": False,
"errors": [f"❌ 语法错误 (行 {e.lineno}): {e.msg}"],
"warnings": []
}
# 安全检查
visitor = SecurityVisitor()
visitor.visit(tree)
# 额外检查代码中是否包含df变量
has_df = any(
isinstance(node, ast.Name) and node.id == 'df'
for node in ast.walk(tree)
)
if not has_df:
visitor.warnings.append(
"⚠️ 代码中未使用 df 变量,可能无法操作数据"
)
# 返回结果
is_valid = len(visitor.errors) == 0
logger.info(
f"AST检查完成: valid={is_valid}, "
f"errors={len(visitor.errors)}, warnings={len(visitor.warnings)}"
)
return {
"valid": is_valid,
"errors": visitor.errors,
"warnings": visitor.warnings
}
except Exception as e:
logger.error(f"AST检查失败: {str(e)}")
return {
"valid": False,
"errors": [f"❌ 检查失败: {str(e)}"],
"warnings": []
}
# ==================== 超时处理 ====================
class TimeoutException(Exception):
"""代码执行超时异常"""
pass
def timeout_handler(signum, frame):
"""超时信号处理器"""
raise TimeoutException("代码执行超时(>30秒")
# ==================== Pandas代码执行 ====================
def execute_pandas_code(data: List[Dict[str, Any]], code: str) -> Dict[str, Any]:
"""
在沙箱环境中执行Pandas代码
Args:
data: JSON格式的数据数组对象
code: Pandas代码必须操作df变量
Returns:
{
"success": bool,
"result_data": List[Dict], # 执行后的数据
"output": str, # 打印输出
"error": str, # 错误信息(如果失败)
"execution_time": float # 执行时间(秒)
}
"""
import time
start_time = time.time()
try:
logger.info(f"开始执行Pandas代码数据行数: {len(data)}")
# 1. 先进行AST检查
validation = validate_code(code)
if not validation["valid"]:
return {
"success": False,
"result_data": None,
"output": "",
"error": f"代码未通过安全检查:\n" + "\n".join(validation["errors"]),
"execution_time": time.time() - start_time
}
# 2. 创建DataFrame
try:
df = pd.DataFrame(data)
logger.info(f"DataFrame创建成功: shape={df.shape}")
except Exception as e:
return {
"success": False,
"result_data": None,
"output": "",
"error": f"数据转换失败: {str(e)}",
"execution_time": time.time() - start_time
}
# 3. 准备安全的执行环境
safe_globals = {
'pd': pd,
'np': np,
'df': df,
'__builtins__': {
# 只允许安全的内置函数
'len': len,
'range': range,
'enumerate': enumerate,
'zip': zip,
'map': map,
'filter': filter,
'list': list,
'dict': dict,
'set': set,
'tuple': tuple,
'str': str,
'int': int,
'float': float,
'bool': bool,
'print': print,
'sum': sum,
'min': min,
'max': max,
'abs': abs,
'round': round,
'sorted': sorted,
'reversed': reversed,
'any': any,
'all': all,
}
}
# 4. 捕获print输出
old_stdout = sys.stdout
sys.stdout = captured_output = io.StringIO()
try:
# 5. 设置超时保护仅在Unix系统上
if hasattr(signal, 'SIGALRM'):
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(EXECUTION_TIMEOUT)
# 6. 执行代码
logger.info("执行用户代码...")
exec(code, safe_globals)
# 7. 取消超时
if hasattr(signal, 'SIGALRM'):
signal.alarm(0)
# 8. 获取执行后的DataFrame
df_result = safe_globals['df']
# 9. 验证结果
if not isinstance(df_result, pd.DataFrame):
raise ValueError(
f"执行后df不是DataFrame类型而是 {type(df_result)}"
)
logger.info(f"代码执行成功: shape={df_result.shape}")
# 10. 转换回JSON格式NaN转为None避免JSON序列化错误
result_data = df_result.replace({np.nan: None}).to_dict('records')
# 11. 获取print输出
output = captured_output.getvalue()
execution_time = time.time() - start_time
logger.info(f"执行完成,耗时: {execution_time:.3f}")
return {
"success": True,
"result_data": result_data,
"output": output.strip() if output else "",
"error": None,
"execution_time": execution_time,
"result_shape": df_result.shape
}
except TimeoutException as e:
logger.error(f"代码执行超时: {str(e)}")
return {
"success": False,
"result_data": None,
"output": captured_output.getvalue(),
"error": f"⏱️ 执行超时: 代码运行超过 {EXECUTION_TIMEOUT}",
"execution_time": time.time() - start_time
}
except Exception as e:
logger.error(f"代码执行失败: {str(e)}")
error_traceback = traceback.format_exc()
return {
"success": False,
"result_data": None,
"output": captured_output.getvalue(),
"error": f"❌ 执行错误:\n{str(e)}\n\n{error_traceback}",
"execution_time": time.time() - start_time
}
finally:
# 恢复stdout
sys.stdout = old_stdout
# 取消超时
if hasattr(signal, 'SIGALRM'):
signal.alarm(0)
except Exception as e:
logger.error(f"代码执行服务失败: {str(e)}")
return {
"success": False,
"result_data": None,
"output": "",
"error": f"服务错误: {str(e)}",
"execution_time": time.time() - start_time
}
# ==================== 测试函数 ====================
def test_dc_executor():
"""测试DC执行器"""
print("=" * 60)
print("测试1: AST安全检查 - 正常代码")
print("=" * 60)
safe_code = """
import pandas as pd
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(df['age_group'].value_counts())
"""
result = validate_code(safe_code)
print(json.dumps(result, indent=2, ensure_ascii=False))
print("\n" + "=" * 60)
print("测试2: AST安全检查 - 危险代码")
print("=" * 60)
dangerous_code = """
import os
import sys
os.system('rm -rf /')
"""
result = validate_code(dangerous_code)
print(json.dumps(result, indent=2, ensure_ascii=False))
print("\n" + "=" * 60)
print("测试3: 代码执行 - 简单操作")
print("=" * 60)
test_data = [
{"patient_id": "P001", "age": 25, "gender": ""},
{"patient_id": "P002", "age": 65, "gender": ""},
{"patient_id": "P003", "age": 45, "gender": ""},
]
simple_code = """
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(f"处理完成,共 {len(df)} 行")
"""
result = execute_pandas_code(test_data, simple_code)
print(json.dumps(result, indent=2, ensure_ascii=False))
print("\n" + "=" * 60)
print("测试完成!")
print("=" * 60)
if __name__ == "__main__":
test_dc_executor()