feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)
核心功能: - 新增AICodeService(550行):AI代码生成核心服务 - 新增AIController(257行):4个API端点 - 新增dc_tool_c_ai_history表:存储对话历史 - 实现自我修正机制:最多3次智能重试 - 集成LLMFactory:复用通用能力层 - 10个Few-shot示例:覆盖Level 1-4场景 技术优化: - 修复NaN序列化问题(Python端转None) - 修复数据传递问题(从Session获取真实数据) - 优化System Prompt(明确环境信息) - 调整Few-shot示例(移除import语句) 测试结果: - 通过率:9/11(81.8%) 达到MVP标准 - 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类 - 待优化:数值清洗、智能去重(已记录技术债务TD-C-006) API端点: - POST /api/v1/dc/tool-c/ai/generate(生成代码) - POST /api/v1/dc/tool-c/ai/execute(执行代码) - POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位) - GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史) 文档更新: - 新增Day 3开发完成总结(770行) - 新增复杂场景优化技术债务(TD-C-006) - 更新工具C当前状态文档 - 更新技术债务清单 影响范围: - backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件) - backend/scripts/create-tool-c-ai-history-table.mjs(新增) - backend/prisma/schema.prisma(新增DcToolCAiHistory模型) - extraction_service/services/dc_executor.py(NaN序列化修复) - docs/03-业务模块/DC-数据清洗整理/*(5份文档更新) Breaking Changes: 无 总代码行数:+950行 Refs: #Tool-C-Day3
This commit is contained in:
427
extraction_service/services/dc_executor.py
Normal file
427
extraction_service/services/dc_executor.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
DC工具C - Pandas代码执行服务
|
||||
|
||||
功能:
|
||||
- AST静态代码检查(安全验证)
|
||||
- Pandas代码执行(沙箱环境)
|
||||
- 危险模块拦截
|
||||
- 超时保护
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import io
|
||||
import traceback
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from loguru import logger
|
||||
import signal
|
||||
import json
|
||||
|
||||
|
||||
# ==================== 配置常量 ====================
|
||||
|
||||
# Blacklist of top-level module names that user code is not allowed to import.
# A few builtin function names are listed as well; they are matched as if
# they were module names by the import checks.
DANGEROUS_MODULES = {
    # operating system, processes and filesystem
    "os",
    "sys",
    "subprocess",
    "shutil",
    "glob",
    # networking
    "socket",
    "urllib",
    "requests",
    "http",
    # object persistence
    "pickle",
    "shelve",
    "dbm",
    # dynamic import machinery
    "importlib",
    "__import__",
    # dynamic code evaluation
    "eval",
    "exec",
    "compile",
    # file / console access
    "open",
    "input",
    "file",
}

# Blacklist of builtin function names that user code is not allowed to call.
DANGEROUS_BUILTINS = {
    # dynamic code evaluation / import
    "eval",
    "exec",
    "compile",
    "__import__",
    # file / console access
    "open",
    "input",
    "file",
    # attribute manipulation by name
    "getattr",
    "setattr",
    "delattr",
    # namespace introspection
    "globals",
    "locals",
    "vars",
}

# Modules considered safe; importing anything outside this set (and outside
# the blacklist) only produces a warning.
SAFE_MODULES = {
    "pandas",
    "numpy",
    "pd",
    "np",
    "datetime",
    "math",
    "statistics",
    "json",
    "re",
    "collections",
}

# Maximum wall-clock time, in seconds, granted to a single code execution.
EXECUTION_TIMEOUT = 30
|
||||
|
||||
|
||||
# ==================== AST安全检查 ====================
|
||||
|
||||
class SecurityVisitor(ast.NodeVisitor):
    """AST visitor that records dangerous constructs found in user code.

    Findings are accumulated in two lists:

    * ``errors``   - violations that make the code invalid;
    * ``warnings`` - suspicious but tolerated constructs.

    The checks are purely syntactic (imports, direct calls, attribute names,
    ``del df``); they do not track aliases or values at runtime.
    """

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def visit_Import(self, node: ast.Import):
        """Check ``import x`` statements against the module lists."""
        for alias in node.names:
            # Only the top-level package decides ("os.path" -> "os").
            module = alias.name.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止导入危险模块: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 导入了未知模块: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Check ``from x import y`` statements against the module lists."""
        # node.module is None for purely relative imports ("from . import x");
        # those are not classified here.
        if node.module:
            module = node.module.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止从危险模块导入: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 从未知模块导入: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Check function calls for dangerous builtins and ``.open()``."""
        func = node.func

        # Direct call by bare name, e.g. eval(...), open(...).
        if isinstance(func, ast.Name):
            func_name = func.id
            if func_name in DANGEROUS_BUILTINS:
                self.errors.append(
                    f"🚫 禁止调用危险函数: {func_name}() (行 {node.lineno})"
                )

        # Any method/attribute call named "open", e.g. something.open(...).
        if isinstance(func, ast.Attribute):
            if func.attr == 'open':
                self.errors.append(
                    f"🚫 禁止文件操作: open() (行 {node.lineno})"
                )

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute):
        """Reject access to dunder attributes such as ``__class__``.

        Chains like ``obj.__class__.__mro__[-1].__subclasses__()`` need no
        builtin function at all, so they are not stopped by the call checks
        or by a restricted ``__builtins__`` mapping; blocking the attribute
        names themselves closes that route.
        """
        attr = node.attr
        if attr.startswith('__') and attr.endswith('__'):
            self.errors.append(
                f"🚫 禁止访问特殊属性: {attr} (行 {node.lineno})"
            )
        self.generic_visit(node)

    def visit_Delete(self, node: ast.Delete):
        """Reject ``del df``: the DataFrame variable itself must survive."""
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == 'df':
                self.errors.append(
                    f"🚫 禁止删除DataFrame变量: del df (行 {node.lineno})"
                )
        self.generic_visit(node)
|
||||
|
||||
|
||||
def validate_code(code: str) -> Dict[str, Any]:
    """
    Statically check a piece of Python code via its AST.

    Args:
        code: Python source code to inspect.

    Returns:
        A dict of the form::

            {
                "valid": bool,          # True when no errors were recorded
                "errors": List[str],
                "warnings": List[str]
            }

        Syntax errors and any unexpected failure of the check itself are
        reported as ``valid=False`` with a single error message.
    """
    try:
        logger.info(f"开始AST代码检查,代码长度: {len(code)} 字符")

        # Parse first; a syntax error short-circuits the security scan.
        try:
            tree = ast.parse(code)
        except SyntaxError as syntax_err:
            syntax_report = {
                "valid": False,
                "errors": [f"❌ 语法错误 (行 {syntax_err.lineno}): {syntax_err.msg}"],
                "warnings": []
            }
            return syntax_report

        # Walk the tree and collect security findings.
        checker = SecurityVisitor()
        checker.visit(tree)

        # Extra hint: code that never names `df` probably does nothing useful.
        references_df = False
        for candidate in ast.walk(tree):
            if isinstance(candidate, ast.Name) and candidate.id == 'df':
                references_df = True
                break
        if not references_df:
            checker.warnings.append(
                "⚠️ 代码中未使用 df 变量,可能无法操作数据"
            )

        is_valid = not checker.errors

        logger.info(
            f"AST检查完成: valid={is_valid}, "
            f"errors={len(checker.errors)}, warnings={len(checker.warnings)}"
        )

        report = {
            "valid": is_valid,
            "errors": checker.errors,
            "warnings": checker.warnings
        }
        return report

    except Exception as e:
        logger.error(f"AST检查失败: {str(e)}")
        failure_report = {
            "valid": False,
            "errors": [f"❌ 检查失败: {str(e)}"],
            "warnings": []
        }
        return failure_report
|
||||
|
||||
|
||||
# ==================== 超时处理 ====================
|
||||
|
||||
class TimeoutException(Exception):
    """Signals that code execution exceeded its time limit."""
|
||||
|
||||
|
||||
def timeout_handler(signum, frame):
    """Signal handler that aborts execution by raising TimeoutException.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        TimeoutException: always.
    """
    # Derive the limit from EXECUTION_TIMEOUT instead of hard-coding it, so
    # the message cannot drift from the configured value.
    raise TimeoutException(f"代码执行超时(>{EXECUTION_TIMEOUT}秒)")
|
||||
|
||||
|
||||
# ==================== Pandas代码执行 ====================
|
||||
|
||||
def execute_pandas_code(data: List[Dict[str, Any]], code: str) -> Dict[str, Any]:
    """
    Validate and execute Pandas code against a DataFrame built from ``data``.

    The code is first checked with ``validate_code``; if it passes, it is run
    with ``exec`` in a namespace exposing ``pd``, ``np``, ``df`` and a
    whitelist of builtins. After execution the ``df`` name in that namespace
    is read back and must still be a DataFrame.

    Args:
        data: Records (list of dicts) used to build the DataFrame ``df``.
        code: Pandas code that operates on the ``df`` variable.

    Returns:
        {
            "success": bool,
            "result_data": List[Dict] or None,  # rows after execution, NaN -> None
            "output": str,                      # captured print output
            "error": str or None,               # error message on failure
            "execution_time": float             # elapsed seconds
        }
        On success the dict additionally carries ``"result_shape"``
        (the shape tuple of the resulting DataFrame).
    """
    import threading
    import time

    start_time = time.time()

    def _failure(error: str, output: str = "") -> Dict[str, Any]:
        # Single place that builds the failure payload shared by all error paths.
        return {
            "success": False,
            "result_data": None,
            "output": output,
            "error": error,
            "execution_time": time.time() - start_time
        }

    try:
        logger.info(f"开始执行Pandas代码,数据行数: {len(data)}")

        # 1. Static AST check before anything is executed.
        validation = validate_code(code)
        if not validation["valid"]:
            return _failure(
                "代码未通过安全检查:\n" + "\n".join(validation["errors"])
            )

        # 2. Build the DataFrame.
        try:
            df = pd.DataFrame(data)
            logger.info(f"DataFrame创建成功: shape={df.shape}")
        except Exception as e:
            return _failure(f"数据转换失败: {str(e)}")

        # 3. Execution namespace: only whitelisted builtins are reachable by name.
        safe_globals = {
            'pd': pd,
            'np': np,
            'df': df,
            '__builtins__': {
                'len': len,
                'range': range,
                'enumerate': enumerate,
                'zip': zip,
                'map': map,
                'filter': filter,
                'list': list,
                'dict': dict,
                'set': set,
                'tuple': tuple,
                'str': str,
                'int': int,
                'float': float,
                'bool': bool,
                'print': print,
                'sum': sum,
                'min': min,
                'max': max,
                'abs': abs,
                'round': round,
                'sorted': sorted,
                'reversed': reversed,
                'any': any,
                'all': all,
            }
        }

        # signal.signal() may only be called from the main thread; elsewhere it
        # raises ValueError, which would make every execution fail. Arm the
        # alarm only where it can actually work (Unix + main thread).
        use_alarm = (
            hasattr(signal, 'SIGALRM')
            and threading.current_thread() is threading.main_thread()
        )
        if not use_alarm:
            logger.warning(
                "⚠️ 当前环境不支持SIGALRM超时保护(非Unix系统或非主线程),"
                "代码将在无超时限制下执行"
            )
        previous_handler = None

        # 4. Capture print output.
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        try:
            # 5. Arm the timeout.
            if use_alarm:
                previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
                signal.alarm(EXECUTION_TIMEOUT)

            # 6. Run the code.
            # NOTE(security): exec() on externally supplied code. Isolation
            # relies solely on the AST check above and the reduced builtins.
            logger.info("执行用户代码...")
            exec(code, safe_globals)

            # 7. Disarm the timeout as soon as the user code is done.
            if use_alarm:
                signal.alarm(0)

            # 8. Read back df (the code may have rebound the name).
            df_result = safe_globals['df']

            # 9. The result must still be a DataFrame.
            if not isinstance(df_result, pd.DataFrame):
                raise ValueError(
                    f"执行后df不是DataFrame类型,而是 {type(df_result)}"
                )

            logger.info(f"代码执行成功: shape={df_result.shape}")

            # 10. Back to records; NaN becomes None to avoid JSON
            #     serialization errors.
            result_data = df_result.replace({np.nan: None}).to_dict('records')

            # 11. Collect printed output.
            output = captured_output.getvalue()

            execution_time = time.time() - start_time
            logger.info(f"执行完成,耗时: {execution_time:.3f}秒")

            return {
                "success": True,
                "result_data": result_data,
                "output": output.strip() if output else "",
                "error": None,
                "execution_time": execution_time,
                "result_shape": df_result.shape
            }

        except TimeoutException as e:
            logger.error(f"代码执行超时: {str(e)}")
            return _failure(
                f"⏱️ 执行超时: 代码运行超过 {EXECUTION_TIMEOUT} 秒",
                captured_output.getvalue()
            )

        except Exception as e:
            logger.error(f"代码执行失败: {str(e)}")
            error_traceback = traceback.format_exc()
            return _failure(
                f"❌ 执行错误:\n{str(e)}\n\n{error_traceback}",
                captured_output.getvalue()
            )

        finally:
            # Always restore stdout.
            sys.stdout = old_stdout

            # Always disarm the alarm and put the previous handler back so
            # this function leaves no process-wide signal state behind.
            if use_alarm:
                signal.alarm(0)
                # signal.signal() returns None for handlers not installed
                # from Python; those cannot be re-installed.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

    except Exception as e:
        logger.error(f"代码执行服务失败: {str(e)}")
        return {
            "success": False,
            "result_data": None,
            "output": "",
            "error": f"服务错误: {str(e)}",
            "execution_time": time.time() - start_time
        }
|
||||
|
||||
|
||||
# ==================== 测试函数 ====================
|
||||
|
||||
def test_dc_executor():
    """Manual smoke test: run validate_code and execute_pandas_code on samples
    and print the JSON-formatted results."""
    rule = "=" * 60

    def banner(title, prefix=""):
        # Section header: rule, title, rule.
        print(prefix + rule)
        print(title)
        print(rule)

    def dump(payload):
        print(json.dumps(payload, indent=2, ensure_ascii=False))

    # --- Test 1: code expected to pass the AST check ---
    banner("测试1: AST安全检查 - 正常代码")
    safe_code = (
        "\n"
        "import pandas as pd\n"
        "df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')\n"
        "print(df['age_group'].value_counts())\n"
    )
    dump(validate_code(safe_code))

    # --- Test 2: code expected to be rejected by the AST check ---
    banner("测试2: AST安全检查 - 危险代码", prefix="\n")
    dangerous_code = (
        "\n"
        "import os\n"
        "import sys\n"
        "os.system('rm -rf /')\n"
    )
    dump(validate_code(dangerous_code))

    # --- Test 3: actual execution against a small dataset ---
    banner("测试3: 代码执行 - 简单操作", prefix="\n")
    test_data = [
        {"patient_id": "P001", "age": 25, "gender": "男"},
        {"patient_id": "P002", "age": 65, "gender": "女"},
        {"patient_id": "P003", "age": 45, "gender": "男"},
    ]
    simple_code = (
        "\n"
        "df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')\n"
        "print(f\"处理完成,共 {len(df)} 行\")\n"
    )
    dump(execute_pandas_code(test_data, simple_code))

    banner("测试完成!", prefix="\n")


if __name__ == "__main__":
    test_dc_executor()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user