核心功能:
- 新增AICodeService(550行):AI代码生成核心服务
- 新增AIController(257行):4个API端点
- 新增dc_tool_c_ai_history表:存储对话历史
- 实现自我修正机制:最多3次智能重试
- 集成LLMFactory:复用通用能力层
- 10个Few-shot示例:覆盖Level 1-4场景

技术优化:
- 修复NaN序列化问题(Python端转None)
- 修复数据传递问题(从Session获取真实数据)
- 优化System Prompt(明确环境信息)
- 调整Few-shot示例(移除import语句)

测试结果:
- 通过率:9/11(81.8%) 达到MVP标准
- 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类
- 待优化:数值清洗、智能去重(已记录技术债务TD-C-006)

API端点:
- POST /api/v1/dc/tool-c/ai/generate(生成代码)
- POST /api/v1/dc/tool-c/ai/execute(执行代码)
- POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位)
- GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史)

文档更新:
- 新增Day 3开发完成总结(770行)
- 新增复杂场景优化技术债务(TD-C-006)
- 更新工具C当前状态文档
- 更新技术债务清单

影响范围:
- backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件)
- backend/scripts/create-tool-c-ai-history-table.mjs(新增)
- backend/prisma/schema.prisma(新增DcToolCAiHistory模型)
- extraction_service/services/dc_executor.py(NaN序列化修复)
- docs/03-业务模块/DC-数据清洗整理/*(5份文档更新)

Breaking Changes: 无

总代码行数:+950行

Refs: #Tool-C-Day3
428 lines
12 KiB
Python
428 lines
12 KiB
Python
"""
|
||
DC工具C - Pandas代码执行服务
|
||
|
||
功能:
|
||
- AST静态代码检查(安全验证)
|
||
- Pandas代码执行(沙箱环境)
|
||
- 危险模块拦截
|
||
- 超时保护
|
||
"""
|
||
|
||
import ast
|
||
import sys
|
||
import pandas as pd
|
||
import numpy as np
|
||
import io
|
||
import traceback
|
||
from typing import Dict, Any, List, Tuple
|
||
from loguru import logger
|
||
import signal
|
||
import json
|
||
|
||
|
||
# ==================== 配置常量 ====================
|
||
|
||
# Names rejected when they appear as the root package of an import statement.
DANGEROUS_MODULES = {
    # process, filesystem and interpreter access
    'os',
    'sys',
    'subprocess',
    'shutil',
    'glob',
    # networking
    'socket',
    'urllib',
    'requests',
    'http',
    # object persistence
    'pickle',
    'shelve',
    'dbm',
    # dynamic import / code evaluation
    'importlib',
    '__import__',
    'eval',
    'exec',
    'compile',
    # file and console I/O
    'open',
    'input',
    'file',
}

# Builtin function names that user code is not allowed to call directly.
DANGEROUS_BUILTINS = {
    # code evaluation and dynamic import
    'eval',
    'exec',
    'compile',
    '__import__',
    # file and console I/O
    'open',
    'input',
    'file',
    # reflective attribute access
    'getattr',
    'setattr',
    'delattr',
    # namespace introspection
    'globals',
    'locals',
    'vars',
}

# Modules (and the usual pandas/numpy aliases) that may be imported without
# triggering an "unknown module" warning.
SAFE_MODULES = {
    'pandas',
    'numpy',
    'pd',
    'np',
    'datetime',
    'math',
    'statistics',
    'json',
    're',
    'collections',
}

# Maximum wall-clock time allowed for user code, in seconds.
EXECUTION_TIMEOUT = 30
|
||
|
||
|
||
# ==================== AST安全检查 ====================
|
||
|
||
class SecurityVisitor(ast.NodeVisitor):
    """AST visitor that records dangerous constructs found in user code.

    Findings are accumulated rather than raised, so a single traversal
    reports every problem at once:

    * ``errors``   -- violations; ``validate_code`` treats the code as
      invalid when this list is non-empty.
    * ``warnings`` -- non-blocking observations (e.g. unknown imports).
    """

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def visit_Import(self, node: ast.Import):
        """Check ``import a.b`` statements against the module lists."""
        for alias in node.names:
            # Only the root package decides: ``import os.path`` -> ``os``.
            module = alias.name.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止导入危险模块: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 导入了未知模块: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Check ``from a.b import c`` statements against the module lists."""
        # ``node.module`` is None for ``from . import x``; those are skipped.
        if node.module:
            module = node.module.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止从危险模块导入: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 从未知模块导入: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Reject calls to blacklisted builtins and to any ``.open(...)``."""
        # Direct call by bare name, e.g. ``eval(...)``.
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            if func_name in DANGEROUS_BUILTINS:
                self.errors.append(
                    f"🚫 禁止调用危险函数: {func_name}() (行 {node.lineno})"
                )

        # Method-style call whose attribute is named ``open``, e.g. ``x.open(...)``.
        if isinstance(node.func, ast.Attribute):
            if node.func.attr == 'open':
                self.errors.append(
                    f"🚫 禁止文件操作: open() (行 {node.lineno})"
                )

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute):
        """Reject access to double-underscore ("dunder") attributes.

        Chains such as ``x.__class__.__bases__[0].__subclasses__()`` or
        ``f.__globals__`` reach arbitrary objects without calling any
        blacklisted builtin, so they are treated as errors.
        """
        attr = node.attr
        if attr.startswith('__') and attr.endswith('__'):
            self.errors.append(
                f"🚫 禁止访问双下划线属性: {attr} (行 {node.lineno})"
            )
        # Keep descending so every dunder in a chain is reported.
        self.generic_visit(node)

    def visit_Delete(self, node: ast.Delete):
        """Reject ``del df``; the DataFrame variable itself must survive."""
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == 'df':
                self.errors.append(
                    f"🚫 禁止删除DataFrame变量: del df (行 {node.lineno})"
                )
        self.generic_visit(node)
|
||
|
||
|
||
def validate_code(code: str) -> Dict[str, Any]:
    """Statically check user code by parsing it and walking its AST.

    Args:
        code: Python source text to inspect.

    Returns:
        A dict with three keys:

        * ``"valid"``    -- True when no errors were recorded.
        * ``"errors"``   -- list of blocking problems (syntax or security).
        * ``"warnings"`` -- list of non-blocking observations.

        Any unexpected exception during the check is reported as a single
        error entry with ``valid`` set to False.
    """
    try:
        logger.info(f"开始AST代码检查,代码长度: {len(code)} 字符")

        # Parsing comes first: a syntax error short-circuits everything else.
        try:
            parsed = ast.parse(code)
        except SyntaxError as syntax_err:
            return {
                "valid": False,
                "errors": [f"❌ 语法错误 (行 {syntax_err.lineno}): {syntax_err.msg}"],
                "warnings": []
            }

        # Security pass over the whole tree.
        checker = SecurityVisitor()
        checker.visit(parsed)

        # Warn when the code never mentions ``df`` at all.
        referenced_names = {
            item.id for item in ast.walk(parsed) if isinstance(item, ast.Name)
        }
        if 'df' not in referenced_names:
            checker.warnings.append(
                "⚠️ 代码中未使用 df 变量,可能无法操作数据"
            )

        is_valid = not checker.errors

        logger.info(
            f"AST检查完成: valid={is_valid}, "
            f"errors={len(checker.errors)}, warnings={len(checker.warnings)}"
        )

        return {
            "valid": is_valid,
            "errors": checker.errors,
            "warnings": checker.warnings
        }

    except Exception as exc:
        logger.error(f"AST检查失败: {str(exc)}")
        return {
            "valid": False,
            "errors": [f"❌ 检查失败: {str(exc)}"],
            "warnings": []
        }
|
||
|
||
|
||
# ==================== 超时处理 ====================
|
||
|
||
class TimeoutException(Exception):
    """Raised when user code exceeds the execution time limit."""
|
||
|
||
|
||
def timeout_handler(signum, frame):
    """Signal handler that aborts execution by raising ``TimeoutException``.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        TimeoutException: Always.
    """
    # Derive the limit from EXECUTION_TIMEOUT so the message cannot drift
    # from the value actually passed to ``signal.alarm``.
    raise TimeoutException(f"代码执行超时(>{EXECUTION_TIMEOUT}秒)")
|
||
|
||
|
||
# ==================== Pandas代码执行 ====================
|
||
|
||
def _build_sandbox_globals(df: pd.DataFrame) -> Dict[str, Any]:
    """Return the globals mapping handed to ``exec`` for user code."""
    return {
        'pd': pd,
        'np': np,
        'df': df,
        # Whitelist of builtins. ``__import__`` is not included, so any
        # ``import`` statement inside user code fails at runtime, even for
        # modules that the static check accepts.
        '__builtins__': {
            'len': len,
            'range': range,
            'enumerate': enumerate,
            'zip': zip,
            'map': map,
            'filter': filter,
            'list': list,
            'dict': dict,
            'set': set,
            'tuple': tuple,
            'str': str,
            'int': int,
            'float': float,
            'bool': bool,
            'print': print,
            'sum': sum,
            'min': min,
            'max': max,
            'abs': abs,
            'round': round,
            'sorted': sorted,
            'reversed': reversed,
            'any': any,
            'all': all,
        }
    }


def execute_pandas_code(data: List[Dict[str, Any]], code: str) -> Dict[str, Any]:
    """Validate and run user-supplied Pandas code against ``data``.

    The code is first checked with ``validate_code``. If it passes, ``data``
    is loaded into a DataFrame bound to ``df`` and the code is run via
    ``exec`` with a restricted globals mapping (``pd``, ``np``, ``df`` and a
    whitelist of builtins). Whatever is bound to ``df`` afterwards is the
    result; it must still be a DataFrame.

    NOTE(security): ``exec`` runs in-process. The AST check and the builtin
    whitelist narrow what the code can do but should not be assumed to be a
    complete sandbox for untrusted input.

    Args:
        data: Records to load, as a list of dicts.
        code: Pandas code that operates on the ``df`` variable.

    Returns:
        A dict with the keys ``"success"`` (bool), ``"result_data"`` (list of
        record dicts with NaN replaced by None, or None on failure),
        ``"output"`` (captured ``print`` output), ``"error"`` (message, or
        None on success) and ``"execution_time"`` (seconds). On success the
        dict additionally contains ``"result_shape"``.
    """
    import threading
    import time
    start_time = time.time()

    try:
        logger.info(f"开始执行Pandas代码,数据行数: {len(data)}")

        # 1. Static safety check before anything is executed.
        validation = validate_code(code)
        if not validation["valid"]:
            return {
                "success": False,
                "result_data": None,
                "output": "",
                "error": "代码未通过安全检查:\n" + "\n".join(validation["errors"]),
                "execution_time": time.time() - start_time
            }

        # 2. Build the DataFrame the user code will operate on.
        try:
            df = pd.DataFrame(data)
            logger.info(f"DataFrame创建成功: shape={df.shape}")
        except Exception as e:
            return {
                "success": False,
                "result_data": None,
                "output": "",
                "error": f"数据转换失败: {str(e)}",
                "execution_time": time.time() - start_time
            }

        # 3. Restricted execution environment.
        safe_globals = _build_sandbox_globals(df)

        # SIGALRM exists only on Unix, and signal.signal() raises ValueError
        # when called outside the main thread. Without this guard every call
        # made from a worker thread would fail before the user code ran.
        use_alarm = (
            hasattr(signal, 'SIGALRM')
            and threading.current_thread() is threading.main_thread()
        )
        if not use_alarm:
            logger.warning(
                "⚠️ 当前环境无法使用SIGALRM超时保护(非Unix系统或非主线程),代码将无超时限制执行"
            )
        alarm_installed = False
        previous_handler = None

        # 4. Capture print output produced by the user code.
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        try:
            # 5. Arm the timeout, remembering the handler that was in place.
            if use_alarm:
                previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
                alarm_installed = True
                signal.alarm(EXECUTION_TIMEOUT)

            # 6. Run the user code.
            logger.info("执行用户代码...")
            exec(code, safe_globals)

            # 7. Disarm the timeout; result conversion below is not timed.
            if alarm_installed:
                signal.alarm(0)

            # 8. The user code may have rebound ``df``; read it back.
            df_result = safe_globals['df']

            # 9. The result must still be a DataFrame.
            if not isinstance(df_result, pd.DataFrame):
                raise ValueError(
                    f"执行后df不是DataFrame类型,而是 {type(df_result)}"
                )

            logger.info(f"代码执行成功: shape={df_result.shape}")

            # 10. Back to records; NaN becomes None so the payload is JSON-safe.
            result_data = df_result.replace({np.nan: None}).to_dict('records')

            # 11. Collect whatever the code printed.
            output = captured_output.getvalue()

            execution_time = time.time() - start_time
            logger.info(f"执行完成,耗时: {execution_time:.3f}秒")

            return {
                "success": True,
                "result_data": result_data,
                "output": output.strip() if output else "",
                "error": None,
                "execution_time": execution_time,
                "result_shape": df_result.shape
            }

        except TimeoutException as e:
            logger.error(f"代码执行超时: {str(e)}")
            return {
                "success": False,
                "result_data": None,
                "output": captured_output.getvalue(),
                "error": f"⏱️ 执行超时: 代码运行超过 {EXECUTION_TIMEOUT} 秒",
                "execution_time": time.time() - start_time
            }

        except Exception as e:
            logger.error(f"代码执行失败: {str(e)}")
            error_traceback = traceback.format_exc()
            return {
                "success": False,
                "result_data": None,
                "output": captured_output.getvalue(),
                "error": f"❌ 执行错误:\n{str(e)}\n\n{error_traceback}",
                "execution_time": time.time() - start_time
            }

        finally:
            # Always give stdout back, whatever happened above.
            sys.stdout = old_stdout

            if alarm_installed:
                # Cancel any pending alarm, then put the caller's handler back.
                signal.alarm(0)
                # signal.signal() returns None when the previous handler was
                # not installed from Python; it cannot be re-installed.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

    except Exception as e:
        logger.error(f"代码执行服务失败: {str(e)}")
        return {
            "success": False,
            "result_data": None,
            "output": "",
            "error": f"服务错误: {str(e)}",
            "execution_time": time.time() - start_time
        }
|
||
|
||
|
||
# ==================== 测试函数 ====================
|
||
|
||
def test_dc_executor():
    """Manual smoke test: print the results of three executor scenarios.

    1. ``validate_code`` on ordinary Pandas code.
    2. ``validate_code`` on code importing blacklisted modules.
    3. ``execute_pandas_code`` on a small three-row dataset.
    """
    rule = "=" * 60

    def show_banner(title, leading_blank=True):
        # Every banner except the first is preceded by an empty line.
        print("\n" + rule if leading_blank else rule)
        print(title)
        print(rule)

    def show_result(payload):
        print(json.dumps(payload, indent=2, ensure_ascii=False))

    # --- Scenario 1: static check, harmless code ---
    show_banner("测试1: AST安全检查 - 正常代码", leading_blank=False)

    safe_code = """
import pandas as pd
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(df['age_group'].value_counts())
"""

    show_result(validate_code(safe_code))

    # --- Scenario 2: static check, dangerous code ---
    show_banner("测试2: AST安全检查 - 危险代码")

    dangerous_code = """
import os
import sys
os.system('rm -rf /')
"""

    show_result(validate_code(dangerous_code))

    # --- Scenario 3: real execution on sample rows ---
    show_banner("测试3: 代码执行 - 简单操作")

    test_data = [
        {"patient_id": "P001", "age": 25, "gender": "男"},
        {"patient_id": "P002", "age": 65, "gender": "女"},
        {"patient_id": "P003", "age": 45, "gender": "男"},
    ]

    simple_code = """
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(f"处理完成,共 {len(df)} 行")
"""

    show_result(execute_pandas_code(test_data, simple_code))

    show_banner("测试完成!")


# Run the smoke test only when this file is executed directly as a script.
if __name__ == "__main__":
    test_dc_executor()
|
||
|
||
|