feat(dc/tool-c): 完成AI代码生成服务(Day 3 MVP)
核心功能:
- 新增AICodeService(550行):AI代码生成核心服务
- 新增AIController(257行):4个API端点
- 新增dc_tool_c_ai_history表:存储对话历史
- 实现自我修正机制:最多3次智能重试
- 集成LLMFactory:复用通用能力层
- 10个Few-shot示例:覆盖Level 1-4场景

技术优化:
- 修复NaN序列化问题(Python端转None)
- 修复数据传递问题(从Session获取真实数据)
- 优化System Prompt(明确环境信息)
- 调整Few-shot示例(移除import语句)

测试结果:
- 通过率:9/11(81.8%) 达到MVP标准
- 成功场景:缺失值处理、编码、分箱、BMI、筛选、填补、统计、分类
- 待优化:数值清洗、智能去重(已记录技术债务TD-C-006)

API端点:
- POST /api/v1/dc/tool-c/ai/generate(生成代码)
- POST /api/v1/dc/tool-c/ai/execute(执行代码)
- POST /api/v1/dc/tool-c/ai/process(生成并执行,一步到位)
- GET /api/v1/dc/tool-c/ai/history/:sessionId(对话历史)

文档更新:
- 新增Day 3开发完成总结(770行)
- 新增复杂场景优化技术债务(TD-C-006)
- 更新工具C当前状态文档
- 更新技术债务清单

影响范围:
- backend/src/modules/dc/tool-c/*(新增2个文件,更新1个文件)
- backend/scripts/create-tool-c-ai-history-table.mjs(新增)
- backend/prisma/schema.prisma(新增DcToolCAiHistory模型)
- extraction_service/services/dc_executor.py(NaN序列化修复)
- docs/03-业务模块/DC-数据清洗整理/*(5份文档更新)

Breaking Changes: 无
总代码行数:+950行
Refs: #Tool-C-Day3
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Any
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
import os
|
||||
@@ -58,6 +60,19 @@ from services.nougat_extractor import check_nougat_available, get_nougat_info
|
||||
from services.file_utils import detect_file_type, cleanup_temp_file
|
||||
from services.docx_extractor import extract_docx_mammoth, validate_docx_file
|
||||
from services.txt_extractor import extract_txt, validate_txt_file
|
||||
from services.dc_executor import validate_code, execute_pandas_code
|
||||
|
||||
|
||||
# ==================== Pydantic Models ====================
|
||||
|
||||
class ValidateCodeRequest(BaseModel):
    """Request body for the code-validation endpoint."""
    # Source code to be statically checked.
    code: str
|
||||
|
||||
class ExecuteCodeRequest(BaseModel):
    """Request body for the code-execution endpoint."""
    # Input rows as a JSON array of objects.
    data: List[Dict[str, Any]]
    # Source code to execute against the rows.
    code: str
|
||||
|
||||
|
||||
# ==================== API路由 ====================
|
||||
@@ -484,6 +499,99 @@ async def extract_document(
|
||||
)
|
||||
|
||||
|
||||
# ==================== DC工具C - 代码执行接口 ====================
|
||||
|
||||
@app.post("/api/dc/validate")
async def validate_pandas_code(request: ValidateCodeRequest):
    """
    DC tool C - safety validation endpoint for Pandas code.

    Args:
        request: ValidateCodeRequest
            - code: str  # Pandas code to validate

    Returns:
        JSON body produced by ``validate_code``:
        {
            "valid": bool,
            "errors": List[str],
            "warnings": List[str]
        }

    Raises:
        HTTPException: status 500 when validation itself fails unexpectedly.
    """
    try:
        source = request.code
        logger.info(f"开始验证Pandas代码,长度: {len(source)} 字符")

        # Run the AST-based safety check.
        report = validate_code(source)

        error_count = len(report['errors'])
        warning_count = len(report['warnings'])
        logger.info(
            f"代码验证完成: valid={report['valid']}, "
            f"errors={error_count}, warnings={warning_count}"
        )
        return JSONResponse(content=report)
    except HTTPException:
        # Pass HTTP errors through unchanged.
        raise
    except Exception as exc:
        logger.error(f"代码验证失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"验证失败: {str(exc)}")
|
||||
|
||||
|
||||
@app.post("/api/dc/execute")
async def execute_pandas_code_endpoint(request: ExecuteCodeRequest):
    """
    DC tool C - Pandas code execution endpoint.

    Args:
        request: ExecuteCodeRequest
            - data: List[Dict]  # rows as a JSON array of objects
            - code: str         # Pandas code operating on the ``df`` variable

    Returns:
        JSON body produced by ``execute_pandas_code``:
        {
            "success": bool,
            "result_data": List[Dict],    # rows after execution
            "output": str,                # captured print output
            "error": str,                 # error message (on failure)
            "execution_time": float,      # seconds
            "result_shape": [rows, cols]  # shape of the result
        }

    Raises:
        HTTPException: status 500 when the endpoint itself fails unexpectedly.
    """
    try:
        rows, source = request.data, request.code
        logger.info(
            f"开始执行Pandas代码: "
            f"数据行数={len(rows)}, 代码长度={len(source)} 字符"
        )

        outcome = execute_pandas_code(rows, source)

        if not outcome["success"]:
            logger.warning(
                f"代码执行失败: {outcome.get('error', 'Unknown error')}"
            )
        else:
            logger.info(
                f"代码执行成功: "
                f"结果shape={outcome.get('result_shape')}, "
                f"耗时={outcome['execution_time']:.3f}秒"
            )

        return JSONResponse(content=outcome)
    except HTTPException:
        # Pass HTTP errors through unchanged.
        raise
    except Exception as exc:
        logger.error(f"代码执行接口失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"处理失败: {str(exc)}")
|
||||
|
||||
|
||||
# ==================== 启动配置 ====================
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
63
extraction_service/quick_test.py
Normal file
63
extraction_service/quick_test.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""快速测试DC API"""
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Endpoints of the locally running service and a shared separator line.
VALIDATE_URL = "http://localhost:8000/api/dc/validate"
EXECUTE_URL = "http://localhost:8000/api/dc/execute"
RULE = "=" * 60

print(RULE)
print("DC工具C - Python微服务快速测试")
print(RULE)

# Test 1: validation of harmless code.
print("\n【测试1】代码验证 - 正常代码")
try:
    resp = requests.post(VALIDATE_URL, json={"code": "df['x'] = 1"}, timeout=5)
    print(f" 状态码: {resp.status_code}")
    if resp.status_code != 200:
        print(f" ❌ 测试1失败: {resp.text}")
    else:
        body = resp.json()
        print(f" valid={body['valid']}, errors={body['errors']}, warnings={body['warnings']}")
        print(" ✅ 测试1通过")
except Exception as exc:
    print(f" ❌ 测试1异常: {exc}")

# Test 2: validation of dangerous code, which must be rejected.
print("\n【测试2】代码验证 - 危险代码(应被拦截)")
try:
    resp = requests.post(VALIDATE_URL, json={"code": "import os"}, timeout=5)
    print(f" 状态码: {resp.status_code}")
    if resp.status_code != 200:
        print(f" ❌ 测试2失败: {resp.text}")
    else:
        body = resp.json()
        print(f" valid={body['valid']}, errors数量={len(body.get('errors',[]))}")
        blocked = not body['valid'] and len(body.get('errors',[])) > 0
        if blocked:
            print(" ✅ 测试2通过(危险代码被拦截)")
        else:
            print(" ❌ 测试2失败(危险代码未被拦截)")
except Exception as exc:
    print(f" ❌ 测试2异常: {exc}")

# Test 3: execution of a simple Pandas operation.
print("\n【测试3】代码执行 - 简单Pandas操作")
try:
    payload = {
        "data": [{"age": 25}, {"age": 65}, {"age": 45}],
        "code": "df['old'] = df['age'] > 60",
    }
    resp = requests.post(EXECUTE_URL, json=payload, timeout=10)
    print(f" 状态码: {resp.status_code}")
    if resp.status_code != 200:
        print(f" ❌ 测试3失败: {resp.text}")
    else:
        body = resp.json()
        print(f" success={body.get('success')}, 执行时间={body.get('execution_time',0):.3f}秒")
        if body.get('success'):
            print(f" 结果数据: {body['result_data']}")
            print(" ✅ 测试3通过(代码成功执行)")
        else:
            print(f" ❌ 测试3失败: {body.get('error')}")
except Exception as exc:
    print(f" ❌ 测试3异常: {exc}")

print("\n" + RULE)
print("🎉 Day 1 Python服务测试完成!")
print(RULE)
|
||||
|
||||
427
extraction_service/services/dc_executor.py
Normal file
427
extraction_service/services/dc_executor.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""
|
||||
DC工具C - Pandas代码执行服务
|
||||
|
||||
功能:
|
||||
- AST静态代码检查(安全验证)
|
||||
- Pandas代码执行(沙箱环境)
|
||||
- 危险模块拦截
|
||||
- 超时保护
|
||||
"""
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import io
|
||||
import traceback
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from loguru import logger
|
||||
import signal
|
||||
import json
|
||||
|
||||
|
||||
# ==================== 配置常量 ====================
|
||||
|
||||
# Blacklist of names that user code must never import.
# (A few entries are builtin function names rather than modules; they are
# kept so the set stays exactly what existing callers rely on.)
DANGEROUS_MODULES = {
    'os', 'sys', 'subprocess', 'shutil', 'glob',
    'socket', 'urllib', 'requests', 'http',
    'pickle', 'shelve', 'dbm',
    'importlib', '__import__',
    'eval', 'exec', 'compile',
    'open', 'input', 'file',
}

# Blacklist of builtin functions that user code must never call directly.
DANGEROUS_BUILTINS = {
    'eval', 'exec', 'compile', '__import__',
    'open', 'input', 'file',
    'getattr', 'setattr', 'delattr',
    'globals', 'locals', 'vars',
}

# Modules that are considered safe to import (no warning is emitted).
SAFE_MODULES = {
    'pandas', 'numpy', 'pd', 'np',
    'datetime', 'math', 'statistics',
    'json', 're', 'collections',
}

# Maximum run time for user code, in seconds.
EXECUTION_TIMEOUT = 30


# ==================== AST safety check ====================

class SecurityVisitor(ast.NodeVisitor):
    """
    AST visitor that collects dangerous operations found in user code.

    After ``visit(tree)``:
        errors:   messages for constructs that must block execution.
        warnings: messages for suspicious but tolerated constructs.

    NOTE(review): this is a static, best-effort filter and not a real
    sandbox. It does not cover, for example, file I/O done through pandas
    itself or modules reachable as attributes of allowed modules.
    """

    def __init__(self):
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def visit_Import(self, node: ast.Import):
        """Check ``import x`` statements against the module lists."""
        for alias in node.names:
            # Only the top-level package decides (``os.path`` -> ``os``).
            module = alias.name.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止导入危险模块: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 导入了未知模块: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Check ``from x import y`` statements against the module lists."""
        # ``node.module`` is None for ``from . import y``; nothing to check.
        if node.module:
            module = node.module.split('.')[0]
            if module in DANGEROUS_MODULES:
                self.errors.append(
                    f"🚫 禁止从危险模块导入: {module} (行 {node.lineno})"
                )
            elif module not in SAFE_MODULES:
                self.warnings.append(
                    f"⚠️ 从未知模块导入: {module} (行 {node.lineno})"
                )
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Check function calls for dangerous builtins and ``.open()``."""
        # Direct call of a blacklisted builtin, e.g. ``eval(...)``.
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
            if func_name in DANGEROUS_BUILTINS:
                self.errors.append(
                    f"🚫 禁止调用危险函数: {func_name}() (行 {node.lineno})"
                )

        # Any method call named ``open``, e.g. ``something.open(...)``.
        if isinstance(node.func, ast.Attribute):
            if node.func.attr == 'open':
                self.errors.append(
                    f"🚫 禁止文件操作: open() (行 {node.lineno})"
                )

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute):
        """Reject access to dunder attributes such as ``__class__``.

        Chains like ``().__class__.__bases__[0].__subclasses__()`` reach
        arbitrary objects without calling any blacklisted builtin, so every
        ``__name__``-style attribute access is reported as an error.
        """
        attr = node.attr
        if attr.startswith('__') and attr.endswith('__'):
            self.errors.append(
                f"🚫 禁止访问特殊属性: {attr} (行 {node.lineno})"
            )
        self.generic_visit(node)

    def visit_Delete(self, node: ast.Delete):
        """Check ``del`` statements; deleting ``df`` itself is forbidden."""
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id == 'df':
                self.errors.append(
                    f"🚫 禁止删除DataFrame变量: del df (行 {node.lineno})"
                )
        self.generic_visit(node)
|
||||
|
||||
|
||||
def validate_code(code: str) -> Dict[str, Any]:
    """
    Statically check Python code by walking its AST.

    Args:
        code: Python source code to check.

    Returns:
        {
            "valid": bool,          # True when no errors were found
            "errors": List[str],
            "warnings": List[str]
        }
        Syntax errors and unexpected checker failures are reported as
        ``valid=False`` with a single error message; nothing is raised.
    """
    try:
        logger.info(f"开始AST代码检查,代码长度: {len(code)} 字符")

        # Parse first; a syntax error ends the check immediately.
        try:
            tree = ast.parse(code)
        except SyntaxError as syntax_err:
            return {
                "valid": False,
                "errors": [f"❌ 语法错误 (行 {syntax_err.lineno}): {syntax_err.msg}"],
                "warnings": []
            }

        checker = SecurityVisitor()
        checker.visit(tree)

        # Extra check: warn when the code never mentions the ``df`` name.
        references_df = False
        for node in ast.walk(tree):
            if isinstance(node, ast.Name) and node.id == 'df':
                references_df = True
                break
        if not references_df:
            checker.warnings.append(
                "⚠️ 代码中未使用 df 变量,可能无法操作数据"
            )

        passed = not checker.errors
        error_total = len(checker.errors)
        warning_total = len(checker.warnings)
        logger.info(
            f"AST检查完成: valid={passed}, "
            f"errors={error_total}, warnings={warning_total}"
        )

        return {
            "valid": passed,
            "errors": checker.errors,
            "warnings": checker.warnings
        }

    except Exception as exc:
        logger.error(f"AST检查失败: {str(exc)}")
        return {
            "valid": False,
            "errors": [f"❌ 检查失败: {str(exc)}"],
            "warnings": []
        }
|
||||
|
||||
|
||||
# ==================== 超时处理 ====================
|
||||
|
||||
class TimeoutException(Exception):
    """Raised when code execution times out."""
    pass
|
||||
|
||||
|
||||
def timeout_handler(signum, frame):
    """Signal handler that aborts execution by raising TimeoutException.

    Args:
        signum: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        TimeoutException: always.
    """
    # Report the configured limit instead of a hard-coded "30" so the
    # message stays correct if EXECUTION_TIMEOUT is changed.
    raise TimeoutException(f"代码执行超时(>{EXECUTION_TIMEOUT}秒)")
|
||||
|
||||
|
||||
# ==================== Pandas代码执行 ====================
|
||||
|
||||
def _restricted_import(name, module_globals=None, module_locals=None,
                       fromlist=(), level=0):
    """Stand-in for ``__import__`` that only allows modules in SAFE_MODULES.

    The validator accepts ``import pandas`` / ``import numpy`` and similar,
    so those statements must also work at run time. Relative imports and any
    module whose top-level package is not whitelisted raise ImportError.
    """
    root = name.split('.')[0]
    if level != 0 or root not in SAFE_MODULES:
        raise ImportError(f"🚫 禁止导入模块: {name}")
    return __import__(name, module_globals, module_locals, fromlist, level)


def _build_safe_builtins() -> Dict[str, Any]:
    """Return the reduced ``__builtins__`` mapping exposed to user code."""
    return {
        'len': len,
        'range': range,
        'enumerate': enumerate,
        'zip': zip,
        'map': map,
        'filter': filter,
        'list': list,
        'dict': dict,
        'set': set,
        'tuple': tuple,
        'str': str,
        'int': int,
        'float': float,
        'bool': bool,
        'print': print,
        'sum': sum,
        'min': min,
        'max': max,
        'abs': abs,
        'round': round,
        'sorted': sorted,
        'reversed': reversed,
        'any': any,
        'all': all,
        # Lets whitelisted ``import`` statements run; everything else fails.
        '__import__': _restricted_import,
    }


def execute_pandas_code(data: List[Dict[str, Any]], code: str) -> Dict[str, Any]:
    """
    Execute Pandas code against the given rows with restricted builtins.

    SECURITY NOTE: the code is run with ``exec``. The AST check plus the
    reduced builtins are a best-effort filter, not real isolation; callers
    must not treat this as a safe boundary for hostile input.

    Args:
        data: Rows as a list of dicts; loaded into a DataFrame named ``df``.
        code: Pandas code; whatever is bound to ``df`` afterwards is returned.

    Returns:
        {
            "success": bool,
            "result_data": List[Dict] | None,  # rows after execution
            "output": str,                     # captured print output
            "error": str | None,               # error message on failure
            "execution_time": float,           # seconds
            "result_shape": [rows, cols]       # only present on success
        }
        Failures are reported through ``success=False``; nothing is raised.
    """
    import contextlib
    import threading
    import time

    start_time = time.time()

    def _failure(error: str, output: str = "") -> Dict[str, Any]:
        # Common shape of every unsuccessful result.
        return {
            "success": False,
            "result_data": None,
            "output": output,
            "error": error,
            "execution_time": time.time() - start_time
        }

    try:
        logger.info(f"开始执行Pandas代码,数据行数: {len(data)}")

        # 1. Static AST check first.
        validation = validate_code(code)
        if not validation["valid"]:
            return _failure(
                "代码未通过安全检查:\n" + "\n".join(validation["errors"])
            )

        # 2. Build the DataFrame.
        try:
            df = pd.DataFrame(data)
            logger.info(f"DataFrame创建成功: shape={df.shape}")
        except Exception as e:
            return _failure(f"数据转换失败: {str(e)}")

        # 3. Namespace visible to the user code.
        safe_globals = {
            'pd': pd,
            'np': np,
            'df': df,
            '__builtins__': _build_safe_builtins(),
        }

        # 4. Buffer that receives everything the user code prints.
        captured_output = io.StringIO()

        # 5. SIGALRM exists only on Unix, and signal.signal() raises
        #    ValueError outside the main thread; skip the timeout otherwise.
        use_alarm = (
            hasattr(signal, 'SIGALRM')
            and threading.current_thread() is threading.main_thread()
        )
        previous_handler = None

        try:
            if use_alarm:
                previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
                signal.alarm(EXECUTION_TIMEOUT)

            # 6. Run the code; stdout is restored even if it raises.
            logger.info("执行用户代码...")
            with contextlib.redirect_stdout(captured_output):
                exec(code, safe_globals)

            # 7. Cancel the pending alarm.
            if use_alarm:
                signal.alarm(0)

            # 8./9. Fetch and validate the resulting DataFrame.
            df_result = safe_globals['df']
            if not isinstance(df_result, pd.DataFrame):
                raise ValueError(
                    f"执行后df不是DataFrame类型,而是 {type(df_result)}"
                )

            logger.info(f"代码执行成功: shape={df_result.shape}")

            # 10. Back to JSON-friendly records. NaN and +/-inf cannot be
            #     encoded by a strict JSON encoder, so both become None.
            result_data = (
                df_result
                .replace([np.inf, -np.inf], np.nan)
                .replace({np.nan: None})
                .to_dict('records')
            )

            # 11. Collect print output.
            output = captured_output.getvalue()

            execution_time = time.time() - start_time
            logger.info(f"执行完成,耗时: {execution_time:.3f}秒")

            return {
                "success": True,
                "result_data": result_data,
                "output": output.strip() if output else "",
                "error": None,
                "execution_time": execution_time,
                # A list, matching the documented [rows, cols] contract.
                "result_shape": list(df_result.shape)
            }

        except TimeoutException as e:
            logger.error(f"代码执行超时: {str(e)}")
            return _failure(
                f"⏱️ 执行超时: 代码运行超过 {EXECUTION_TIMEOUT} 秒",
                captured_output.getvalue()
            )

        except Exception as e:
            logger.error(f"代码执行失败: {str(e)}")
            error_traceback = traceback.format_exc()
            return _failure(
                f"❌ 执行错误:\n{str(e)}\n\n{error_traceback}",
                captured_output.getvalue()
            )

        finally:
            if use_alarm:
                # Always cancel the alarm and put the old handler back so
                # other SIGALRM users in this process are not affected.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

    except Exception as e:
        logger.error(f"代码执行服务失败: {str(e)}")
        return _failure(f"服务错误: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 测试函数 ====================
|
||||
|
||||
def test_dc_executor():
    """Exercise the validator and the executor and print their results."""
    rule = "=" * 60

    def show(result):
        # Pretty-print a result dict, keeping non-ASCII text readable.
        print(json.dumps(result, indent=2, ensure_ascii=False))

    print(rule)
    print("测试1: AST安全检查 - 正常代码")
    print(rule)

    safe_code = """
import pandas as pd
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(df['age_group'].value_counts())
"""
    show(validate_code(safe_code))

    print("\n" + rule)
    print("测试2: AST安全检查 - 危险代码")
    print(rule)

    dangerous_code = """
import os
import sys
os.system('rm -rf /')
"""
    show(validate_code(dangerous_code))

    print("\n" + rule)
    print("测试3: 代码执行 - 简单操作")
    print(rule)

    test_data = [
        {"patient_id": "P001", "age": 25, "gender": "男"},
        {"patient_id": "P002", "age": 65, "gender": "女"},
        {"patient_id": "P003", "age": 45, "gender": "男"},
    ]
    simple_code = """
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(f"处理完成,共 {len(df)} 行")
"""
    show(execute_pandas_code(test_data, simple_code))

    print("\n" + rule)
    print("测试完成!")
    print(rule)


if __name__ == "__main__":
    test_dc_executor()
|
||||
|
||||
|
||||
281
extraction_service/test_dc_api.py
Normal file
281
extraction_service/test_dc_api.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
DC工具C - API测试脚本
|
||||
|
||||
测试项:
|
||||
1. 健康检查 (GET /api/health)
|
||||
2. AST安全检查 - 正常代码
|
||||
3. AST安全检查 - 危险代码
|
||||
4. Pandas代码执行 - 简单场景
|
||||
5. Pandas代码执行 - 医疗数据清洗场景
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
|
||||
def print_test_header(title: str):
    """Print a test title framed by two 70-character separator lines."""
    separator = "=" * 70
    # The leading newline keeps a blank line between consecutive tests.
    print(f"\n{separator}")
    print(f" {title}")
    print(separator)
|
||||
|
||||
def print_result(response: "requests.Response"):
    """Print the status code and body of an HTTP response.

    The body is pretty-printed as JSON when it parses as JSON; otherwise the
    raw response text is printed.

    Args:
        response: Response object providing ``status_code``, ``json()`` and
            ``text``.
    """
    print(f"\n状态码: {response.status_code}")
    print("响应内容:")
    try:
        result = response.json()
    except ValueError:
        # JSON decoding errors are ValueError subclasses. A bare ``except``
        # here would also hide KeyboardInterrupt and real bugs.
        print(response.text)
    else:
        print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
def test_health_check():
    """Test 1: health check. Returns True when the endpoint answers 200."""
    print_test_header("测试1: 健康检查")

    try:
        reply = requests.get(f"{BASE_URL}/api/health", timeout=5)
        print_result(reply)

        if reply.status_code != 200:
            print("\n❌ 健康检查失败")
            return False

        print("\n✅ 健康检查通过")
        return True
    except Exception as exc:
        print(f"\n❌ 健康检查异常: {str(exc)}")
        return False
|
||||
|
||||
def test_validate_safe_code():
    """Test 2: AST safety check - harmless code must be reported as valid."""
    print_test_header("测试2: AST安全检查 - 正常代码")

    safe_code = """
import pandas as pd
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(df['age_group'].value_counts())
"""

    try:
        reply = requests.post(
            f"{BASE_URL}/api/dc/validate",
            json={"code": safe_code},
            timeout=5
        )
        print_result(reply)

        if reply.status_code != 200:
            print("\n❌ API调用失败")
            return False

        verdict = reply.json()
        if not verdict.get("valid"):
            print("\n❌ 正常代码被误判为危险")
            return False

        print("\n✅ 正常代码验证通过(valid=True)")
        return True
    except Exception as exc:
        print(f"\n❌ 测试异常: {str(exc)}")
        return False
|
||||
|
||||
def test_validate_dangerous_code():
    """Test 3: AST safety check - dangerous code must be rejected."""
    print_test_header("测试3: AST安全检查 - 危险代码(应该被拦截)")

    dangerous_code = """
import os
import sys
os.system('echo "危险操作"')
eval('print("evil code")')
"""

    try:
        reply = requests.post(
            f"{BASE_URL}/api/dc/validate",
            json={"code": dangerous_code},
            timeout=5
        )
        print_result(reply)

        if reply.status_code != 200:
            print("\n❌ API调用失败")
            return False

        verdict = reply.json()
        # Rejected means: flagged invalid AND at least one error message.
        rejected = not verdict.get("valid") and len(verdict.get("errors", [])) > 0
        if not rejected:
            print("\n❌ 危险代码未被拦截!")
            return False

        print("\n✅ 危险代码成功拦截(valid=False, 有错误信息)")
        return True
    except Exception as exc:
        print(f"\n❌ 测试异常: {str(exc)}")
        return False
|
||||
|
||||
def test_execute_simple_code():
    """Test 4: Pandas code execution - simple scenario adding one column."""
    print_test_header("测试4: Pandas代码执行 - 简单场景")

    test_data = [
        {"patient_id": "P001", "age": 25, "gender": "男"},
        {"patient_id": "P002", "age": 65, "gender": "女"},
        {"patient_id": "P003", "age": 45, "gender": "男"},
        {"patient_id": "P004", "age": 70, "gender": "女"},
    ]

    simple_code = """
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(f"数据处理完成,共 {len(df)} 行")
print(df['age_group'].value_counts())
"""

    try:
        reply = requests.post(
            f"{BASE_URL}/api/dc/execute",
            json={"data": test_data, "code": simple_code},
            timeout=10
        )
        print_result(reply)

        if reply.status_code != 200:
            print("\n❌ API调用失败")
            return False

        outcome = reply.json()
        if not outcome.get("success"):
            print(f"\n❌ 代码执行失败: {outcome.get('error')}")
            return False

        rows = outcome.get("result_data", [])
        print(f"\n结果数据行数: {len(rows)}")
        print(f"执行时间: {outcome.get('execution_time', 0):.3f}秒")

        # The new column must be present in the returned rows.
        if len(rows) > 0 and 'age_group' in rows[0]:
            print("\n✅ 简单代码执行成功(新增列 age_group)")
            return True

        print("\n❌ 代码执行成功但结果不正确")
        return False
    except Exception as exc:
        print(f"\n❌ 测试异常: {str(exc)}")
        return False
|
||||
|
||||
def test_execute_medical_cleaning():
    """Test 5: Pandas code execution - medical data cleaning scenario."""
    print_test_header("测试5: Pandas代码执行 - 医疗数据清洗场景")

    # Simulated medical records, including a missing and an implausible age.
    medical_data = [
        {"patient_id": "P001", "age": 25, "gender": "男", "sbp": 120, "dbp": 80},
        {"patient_id": "P002", "age": 65, "gender": "女", "sbp": 150, "dbp": 95},
        {"patient_id": "P003", "age": 45, "gender": "男", "sbp": 135, "dbp": 85},
        {"patient_id": "P004", "age": None, "gender": "女", "sbp": 160, "dbp": 100},
        {"patient_id": "P005", "age": 200, "gender": "男", "sbp": 110, "dbp": 70},
    ]

    # Multi-step cleaning code sent to the service.
    medical_code = """
import numpy as np

# 1. 清理异常年龄值(>120视为异常)
df['age'] = df['age'].apply(lambda x: np.nan if x is None or x > 120 else x)

# 2. 计算血压状态(收缩压 >= 140 或舒张压 >= 90 为高血压)
df['hypertension'] = df.apply(
    lambda row: '高血压' if row['sbp'] >= 140 or row['dbp'] >= 90 else '正常',
    axis=1
)

# 3. 统计结果
print(f"总样本数: {len(df)}")
print(f"年龄缺失数: {df['age'].isna().sum()}")
print(f"高血压人数: {(df['hypertension'] == '高血压').sum()}")
"""

    try:
        reply = requests.post(
            f"{BASE_URL}/api/dc/execute",
            json={"data": medical_data, "code": medical_code},
            timeout=10
        )
        print_result(reply)

        if reply.status_code != 200:
            print("\n❌ API调用失败")
            return False

        outcome = reply.json()
        if not outcome.get("success"):
            print(f"\n❌ 代码执行失败: {outcome.get('error')}")
            return False

        rows = outcome.get("result_data", [])
        print(f"\n结果数据行数: {len(rows)}")
        print(f"执行时间: {outcome.get('execution_time', 0):.3f}秒")

        # The new column must be present in the returned rows.
        if not (len(rows) > 0 and 'hypertension' in rows[0]):
            print("\n❌ 代码执行成功但结果不正确")
            return False

        # Count the rows classified as hypertensive.
        flagged = [row for row in rows if row.get('hypertension') == '高血压']
        hypertension_count = len(flagged)
        print(f"高血压人数: {hypertension_count}")

        print("\n✅ 医疗数据清洗场景执行成功")
        return True
    except Exception as exc:
        print(f"\n❌ 测试异常: {str(exc)}")
        return False
|
||||
|
||||
def main():
    """Run every API test in order and print a pass/fail summary."""
    banner = "🚀" * 35
    print("\n" + banner)
    print(" DC工具C - Python微服务API测试")
    print(banner)

    # Tests run in this order; each returns True on success.
    suite = [
        ("健康检查", test_health_check),
        ("AST检查-正常代码", test_validate_safe_code),
        ("AST检查-危险代码", test_validate_dangerous_code),
        ("代码执行-简单场景", test_execute_simple_code),
        ("代码执行-医疗清洗", test_execute_medical_cleaning),
    ]
    results = {}
    for label, check in suite:
        results[label] = check()

    # Summary table.
    print("\n" + "=" * 70)
    print(" 测试结果汇总")
    print("=" * 70)

    for test_name, ok in results.items():
        status = "✅ 通过" if ok else "❌ 失败"
        print(f"{test_name:20s}: {status}")

    total = len(results)
    passed = len([ok for ok in results.values() if ok])
    if total > 0:
        success_rate = passed / total * 100
    else:
        success_rate = 0

    print("\n" + "-" * 70)
    print(f"总计: {passed}/{total} 通过 ({success_rate:.1f}%)")
    print("-" * 70)

    if passed != total:
        print(f"\n⚠️ 有 {total - passed} 个测试失败,请检查")
    else:
        print("\n🎉 所有测试通过!Day 1 Python服务开发完成!")

    print("\n")


if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
47
extraction_service/test_execute_simple.py
Normal file
47
extraction_service/test_execute_simple.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""简单的代码执行测试"""
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 测试数据
|
||||
# Rows sent to the service.
test_data = [
    {"patient_id": "P001", "age": 25, "gender": "男"},
    {"patient_id": "P002", "age": 65, "gender": "女"},
    {"patient_id": "P003", "age": 45, "gender": "男"},
]

# Code sent to the service.
test_code = """
df['age_group'] = df['age'].apply(lambda x: '老年' if x > 60 else '非老年')
print(f"处理完成,共 {len(df)} 行")
"""

rule = "=" * 60
print(rule)
print("测试: Pandas代码执行")
print(rule)

try:
    payload = {"data": test_data, "code": test_code}
    response = requests.post(
        "http://localhost:8000/api/dc/execute",
        json=payload,
        timeout=10
    )

    print(f"\n状态码: {response.status_code}")
    result = response.json()
    print(json.dumps(result, indent=2, ensure_ascii=False))

    if not result.get("success"):
        print(f"\n❌ 代码执行失败: {result.get('error')}")
    else:
        print("\n✅ 代码执行成功!")
        print(f"结果数据: {len(result.get('result_data', []))} 行")
        print(f"执行时间: {result.get('execution_time', 0):.3f}秒")
        print(f"\n打印输出:\n{result.get('output', '')}")
        print("\n结果数据示例:")
        # Show at most the first three rows.
        preview = result.get('result_data', [])[:3]
        for row in preview:
            print(f" {row}")

except Exception as exc:
    print(f"\n❌ 测试异常: {str(exc)}")
|
||||
|
||||
|
||||
27
extraction_service/test_module.py
Normal file
27
extraction_service/test_module.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""测试dc_executor模块"""
|
||||
print("测试dc_executor模块导入...")
try:
    # Importing inside the try block so import failures are reported below.
    from services.dc_executor import validate_code, execute_pandas_code
    print("✅ 模块导入成功")

    # Validation.
    print("\n测试validate_code...")
    validation = validate_code("df['x'] = 1")
    print(f"✅ validate_code成功: {validation}")

    # Execution.
    print("\n测试execute_pandas_code...")
    sample_rows = [{"age": 25}, {"age": 65}]
    execution = execute_pandas_code(sample_rows, "df['old'] = df['age'] > 60")
    print(f"✅ execute_pandas_code成功: success={execution['success']}")
    if execution['success']:
        print(f" 结果: {execution['result_data']}")

    print("\n🎉 所有模块测试通过!")

except Exception as exc:
    print(f"❌ 测试失败: {exc}")
    import traceback
    traceback.print_exc()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user