Files
AIclinicalresearch/extraction_service/operations/data_profile.py
HaHafeng 3446909ff7 feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development
Phase I - Session Blackboard + READ Layer:
- SessionBlackboardService with Postgres-Only cache
- DataProfileService for data overview generation
- PicoInferenceService for LLM-driven PICO extraction
- Frontend DataContextCard and VariableDictionaryPanel
- E2E tests: 31/31 passed

Phase II - Conversation Layer LLM + Intent Router:
- ConversationService with SSE streaming
- IntentRouterService (rule-first + LLM fallback, 6 intents)
- SystemPromptService with 6-segment dynamic assembly
- TokenTruncationService for context management
- ChatHandlerService as unified chat entry
- Frontend SSAChatPane and useSSAChat hook
- E2E tests: 38/38 passed

Phase III - Method Consultation + AskUser Standardization:
- ToolRegistryService with Repository Pattern
- MethodConsultService with DecisionTable + LLM enhancement
- AskUserService with global interrupt handling
- Frontend AskUserCard component
- E2E tests: 13/13 passed

Phase IV - Dialogue-Driven Analysis + QPER Integration:
- ToolOrchestratorService (plan/execute/report)
- analysis_plan SSE event for WorkflowPlan transmission
- Dual-channel confirmation (ask_user card + workspace button)
- PICO as optional hint for LLM parsing
- E2E tests: 25/25 passed

R Statistics Service:
- 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon
- Enhanced guardrails and block helpers
- Comprehensive test suite (run_all_tools_test.js)

Documentation:
- Updated system status document (v5.9)
- Updated SSA module status and development plan (v1.8)

Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-22 18:53:39 +08:00

496 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SSA DataProfile - 数据画像生成模块 (Phase 2A → Phase I)
提供数据上传时的快速画像生成,用于 LLM 生成 SAP分析计划
高性能实现,利用 pandas 的向量化操作。
Phase I 新增:
- compute_normality_tests(df) — Shapiro-Wilk / K-S 正态性检验
- compute_complete_cases(df) — 完整病例计数
- analyze_variable_detail() — 单变量详细分析(直方图+Q-Q图数据
"""
import re
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from loguru import logger
from scipy import stats as scipy_stats
def generate_data_profile(df: pd.DataFrame, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Build the DataProfile for an uploaded data frame.

    Profiles every column, tallies column types and overall missingness,
    then attaches the Phase I extras: per-column normality tests and the
    complete-case count.

    Args:
        df: input data frame
        max_unique_values: maximum number of distinct values listed for
            categorical columns

    Returns:
        DataProfile JSON structure with 'columns', 'summary',
        'normalityTests' and 'completeCaseCount'.
    """
    logger.info(f"开始生成数据画像: {df.shape[0]} 行, {df.shape[1]}")
    columns = [analyze_column(df[name], name, max_unique_values) for name in df.columns]

    # Tally the three explicitly-typed kinds; everything else counts as text.
    tallies = {'numeric': 0, 'categorical': 0, 'datetime': 0}
    for profile in columns:
        kind = profile['type']
        if kind in tallies:
            tallies[kind] += 1

    n_rows, n_cols = df.shape
    total_cells = n_rows * n_cols
    total_missing = df.isna().sum().sum()
    summary = {
        'totalRows': int(n_rows),
        'totalColumns': int(n_cols),
        'numericColumns': tallies['numeric'],
        'categoricalColumns': tallies['categorical'],
        'datetimeColumns': tallies['datetime'],
        'textColumns': int(n_cols) - tallies['numeric'] - tallies['categorical'] - tallies['datetime'],
        'overallMissingRate': round(total_missing / total_cells * 100, 2) if total_cells > 0 else 0,
        'totalMissingCells': int(total_missing)
    }

    normality_tests = compute_normality_tests(df, columns)
    complete_case_count = compute_complete_cases(df)
    logger.info(f"数据画像生成完成: {tallies['numeric']} 数值列, {tallies['categorical']} 分类列, 完整病例 {complete_case_count}")
    return {
        'columns': columns,
        'summary': summary,
        'normalityTests': normality_tests,
        'completeCaseCount': complete_case_count
    }
def analyze_column(col: pd.Series, col_name: str, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Profile a single column.

    Args:
        col: column values
        col_name: column name
        max_unique_values: maximum number of distinct values reported for
            categorical columns

    Returns:
        Column profile dict with missingness/cardinality counts; the
        type-specific statistics are merged in, and 'isIdLike' flags
        columns that should be excluded from analysis (IDs, dates,
        near-unique strings).
    """
    total_count = len(col)
    non_null = col.dropna()
    missing_count = int(col.isna().sum())
    unique_count = int(non_null.nunique())
    col_type = infer_column_type(col, unique_count, total_count)

    profile = {
        'name': col_name,
        'type': col_type,
        'missingCount': missing_count,
        'missingRate': round(missing_count / total_count * 100, 2) if total_count > 0 else 0,
        'uniqueCount': unique_count,
        'totalCount': total_count
    }

    # Merge in the statistics specific to the inferred type.
    if col_type == 'numeric':
        extra = analyze_numeric_column(non_null)
    elif col_type == 'categorical':
        extra = analyze_categorical_column(non_null, max_unique_values)
    elif col_type == 'datetime':
        extra = analyze_datetime_column(non_null)
    else:
        extra = {}
    profile.update(extra)

    profile['isIdLike'] = _detect_id_like(col_name, col_type, unique_count, total_count)
    return profile
import re
_ID_PATTERNS = re.compile(
r'(_id|_no|_code|编号|序号|流水号|主键|record_date|visit_date|enroll_date)$|^(id|ID|Id)_|^(patient|subject|sample|record)_?id$',
re.IGNORECASE
)
def _detect_id_like(col_name: str, col_type: str, unique_count: int, total_count: int) -> bool:
"""
判断列是否为非分析变量ID / 高基数字符串 / 日期)
标记为 True 后Q 层 Context Pruning 会在注入 Prompt 前物理剔除这些列
"""
if col_type == 'datetime':
return True
if _ID_PATTERNS.search(col_name):
return True
if col_type == 'text' and total_count > 0 and unique_count / total_count > 0.95:
return True
if col_type == 'categorical' and total_count > 0 and unique_count / total_count > 0.95:
return True
return False
def infer_column_type(col: pd.Series, unique_count: int, total_count: int) -> str:
    """
    Infer the semantic type of a column.

    Numeric dtypes with very low cardinality are treated as coded
    categoricals; object/string columns are probed as numeric, then
    datetime, before falling back to free text.

    Args:
        col: column values (may contain nulls)
        unique_count: number of distinct non-null values
        total_count: total number of rows

    Returns:
        'numeric' | 'categorical' | 'datetime' | 'text'
    """
    if pd.api.types.is_datetime64_any_dtype(col):
        return 'datetime'
    if pd.api.types.is_numeric_dtype(col):
        unique_ratio = unique_count / total_count if total_count > 0 else 0
        # Few distinct values relative to sample size => coded categorical.
        if unique_count <= 10 and unique_ratio < 0.05:
            return 'categorical'
        return 'numeric'
    if col.dtype == 'object' or col.dtype == 'string':
        non_null = col.dropna()
        if len(non_null) == 0:
            return 'text'
        unique_ratio = unique_count / total_count if total_count > 0 else 0
        if unique_count <= 30 and unique_ratio < 0.1:
            return 'categorical'
        # Narrowed from bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; parsing failures raise
        # ValueError/TypeError (datetime parsing may also overflow).
        try:
            pd.to_numeric(non_null, errors='raise')
            return 'numeric'
        except (ValueError, TypeError):
            pass
        try:
            pd.to_datetime(non_null, errors='raise')
            return 'datetime'
        except (ValueError, TypeError, OverflowError):
            pass
        return 'text'
    return 'text'
def analyze_numeric_column(col: pd.Series) -> Dict[str, Any]:
    """
    Compute descriptive statistics for a numeric column.

    Args:
        col: non-null column values; non-numeric entries are coerced to
            NaN and dropped.

    Returns:
        Dict with central tendency, spread, IQR/Tukey-fence outlier
        counts and shape statistics; empty dict when no numeric values
        remain. 'std' is None for a single observation (fix: pandas
        returns NaN there, which is not JSON-safe); 'skewness' and
        'kurtosis' are None below their minimum sample sizes.
    """
    if len(col) == 0:
        return {}
    col_numeric = pd.to_numeric(col, errors='coerce').dropna()
    n = len(col_numeric)
    if n == 0:
        return {}
    q1 = float(col_numeric.quantile(0.25))
    q3 = float(col_numeric.quantile(0.75))
    iqr = q3 - q1
    # Tukey's fences: values beyond 1.5*IQR from the quartiles are outliers.
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    outlier_count = int(((col_numeric < lower_bound) | (col_numeric > upper_bound)).sum())
    return {
        'mean': round(float(col_numeric.mean()), 4),
        # Sample std is undefined (NaN) for n < 2; emit None instead.
        'std': round(float(col_numeric.std()), 4) if n >= 2 else None,
        'median': round(float(col_numeric.median()), 4),
        'min': round(float(col_numeric.min()), 4),
        'max': round(float(col_numeric.max()), 4),
        'q1': round(q1, 4),
        'q3': round(q3, 4),
        'iqr': round(iqr, 4),
        'outlierCount': outlier_count,
        'outlierRate': round(outlier_count / n * 100, 2),
        'skewness': round(float(col_numeric.skew()), 4) if n >= 3 else None,
        'kurtosis': round(float(col_numeric.kurtosis()), 4) if n >= 4 else None
    }
def analyze_categorical_column(col: pd.Series, max_values: int = 20) -> Dict[str, Any]:
    """
    Compute frequency statistics for a categorical column.

    Args:
        col: non-null column values
        max_values: maximum number of levels reported in 'topValues'

    Returns:
        Dict with the top levels (value/count/percentage), the total
        level count and the mode; empty dict for an empty column.
    """
    if len(col) == 0:
        return {}
    counts = col.value_counts()
    n = len(col)
    # value_counts is sorted by frequency, so head() yields the top levels.
    top_values = [
        {
            'value': str(level),
            'count': int(freq),
            'percentage': round(freq / n * 100, 2)
        }
        for level, freq in counts.head(max_values).items()
    ]
    has_levels = len(counts) > 0
    return {
        'topValues': top_values,
        'totalLevels': int(len(counts)),
        'modeValue': str(counts.index[0]) if has_levels else None,
        'modeCount': int(counts.iloc[0]) if has_levels else 0
    }
def analyze_datetime_column(col: pd.Series) -> Dict[str, Any]:
    """
    Compute summary statistics for a datetime column.

    Args:
        col: non-null column values; unparseable entries are coerced to
            NaT and dropped.

    Returns:
        Dict with ISO min/max dates and the range as a string; empty dict
        for an empty column or when no values parse.
    """
    if len(col) == 0:
        return {}
    # Narrowed from bare `except:` which also swallowed
    # KeyboardInterrupt/SystemExit; kept as a defensive guard against
    # unexpected parsing/arithmetic failures.
    try:
        col_dt = pd.to_datetime(col, errors='coerce').dropna()
        if len(col_dt) == 0:
            return {}
        return {
            'minDate': col_dt.min().isoformat(),
            'maxDate': col_dt.max().isoformat(),
            'dateRange': str(col_dt.max() - col_dt.min())
        }
    except (ValueError, TypeError, OverflowError):
        return {}
def get_quality_score(profile: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compute a data-quality score from a DataProfile.

    Penalizes overall missingness, per-column outlier rates and small
    sample sizes, then maps the score to a letter grade.

    Args:
        profile: DataProfile dict with 'summary' and 'columns' keys

    Returns:
        Dict with score (0-100), grade (A-D), grade description, and the
        collected issues and recommendations.
    """
    summary = profile.get('summary', {})
    cols = profile.get('columns', [])
    issues: List[str] = []
    recommendations: List[str] = []
    penalty = 0.0

    # Overall missingness: tiered penalty.
    miss = summary.get('overallMissingRate', 0)
    if miss > 20:
        penalty += 30
        issues.append(f"整体缺失率较高 ({miss}%)")
        recommendations.append("建议检查数据完整性,考虑缺失值处理")
    elif miss > 10:
        penalty += 15
        issues.append(f"整体缺失率中等 ({miss}%)")
        recommendations.append("建议在分析前处理缺失值")
    elif miss > 5:
        penalty += 5
        issues.append(f"存在少量缺失 ({miss}%)")

    # Per-column outlier burden: flat penalty per offending column.
    for c in cols:
        if c.get('outlierRate', 0) > 10:
            penalty += 5
            issues.append(f"'{c['name']}' 存在较多异常值 ({c['outlierRate']}%)")
            recommendations.append(f"建议检查列 '{c['name']}' 的异常值")

    # Sample size: small samples weaken statistical power.
    n = summary.get('totalRows', 0)
    if n < 30:
        penalty += 20
        issues.append(f"样本量较小 (n={n})")
        recommendations.append("小样本可能影响统计检验的效力")
    elif n < 100:
        penalty += 10
        issues.append(f"样本量中等 (n={n})")

    score = max(0, min(100, 100.0 - penalty))

    # Map score to letter grade via descending thresholds.
    grade, grade_desc = 'D', '数据质量很差'
    for cutoff, letter, desc in ((80, 'A', '数据质量良好'),
                                 (60, 'B', '数据质量中等'),
                                 (40, 'C', '数据质量较差')):
        if score >= cutoff:
            grade, grade_desc = letter, desc
            break

    return {
        'score': round(score, 1),
        'grade': grade,
        'gradeDescription': grade_desc,
        'issues': issues,
        'recommendations': recommendations
    }
# ────────────────────────────────────────────
# Phase I 新增函数
# ────────────────────────────────────────────
def compute_normality_tests(df: pd.DataFrame, columns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Run a normality test on every numeric column.

    Shapiro-Wilk is used for samples of up to 5000 values; larger samples
    fall back to Kolmogorov-Smirnov against a normal with the sample's
    mean/std. Columns with fewer than 3 numeric values are skipped, and
    per-column failures are logged and skipped.

    Args:
        df: source data frame
        columns: column profiles (only those with type 'numeric' are tested)

    Returns:
        List of result dicts: variable, method, statistic, pValue, isNormal.
    """
    results: List[Dict[str, Any]] = []
    for spec in columns:
        if spec['type'] != 'numeric':
            continue
        name = spec['name']
        try:
            values = pd.to_numeric(df[name], errors='coerce').dropna()
            n = len(values)
            if n < 3:
                continue
            if n <= 5000:
                method = 'shapiro_wilk'
                stat, p_value = scipy_stats.shapiro(values)
            else:
                method = 'kolmogorov_smirnov'
                stat, p_value = scipy_stats.kstest(
                    values, 'norm', args=(values.mean(), values.std())
                )
            results.append({
                'variable': name,
                'method': method,
                'statistic': round(float(stat), 4),
                # isNormal uses the unrounded p-value against alpha = 0.05
                'pValue': round(float(p_value), 4),
                'isNormal': bool(p_value >= 0.05)
            })
        except Exception as e:
            logger.warning(f"正态性检验失败 [{name}]: {e}")
    return results
def compute_complete_cases(df: pd.DataFrame) -> int:
    """Count the rows with no missing value in any column (complete cases)."""
    return int(df.notna().all(axis=1).sum())
def analyze_variable_detail(df: pd.DataFrame, variable_name: str,
                            max_bins: int = 30, max_qq_points: int = 200) -> Dict[str, Any]:
    """
    Detailed single-variable analysis (Phase I: backend for the
    get_variable_detail tool).

    For numeric variables returns descriptive statistics, histogram data,
    a normality test and Q-Q plot points; for categorical variables, the
    level distribution. Histogram bin count is hard-capped at max_bins
    (H2 guardrail) and Q-Q points at max_qq_points.

    Args:
        df: source data frame
        variable_name: column to analyze
        max_bins: upper limit on histogram bins
        max_qq_points: upper limit on Q-Q plot sample points

    Returns:
        {'success': False, 'error': ...} when the column is missing,
        otherwise a result dict whose extra keys depend on the inferred type.
    """
    if variable_name not in df.columns:
        return {'success': False, 'error': f"变量 '{variable_name}' 不存在"}
    col = df[variable_name]
    non_null = col.dropna()
    total = len(col)
    missing = int(col.isna().sum())
    unique_count = int(non_null.nunique())
    col_type = infer_column_type(col, unique_count, total)
    result: Dict[str, Any] = {
        'success': True,
        'variable': variable_name,
        'type': col_type,
        'totalCount': total,
        'missingCount': missing,
        'missingRate': round(missing / total * 100, 2) if total > 0 else 0,
        'uniqueCount': unique_count,
    }
    if col_type == 'numeric':
        col_numeric = pd.to_numeric(non_null, errors='coerce').dropna()
        if len(col_numeric) == 0:
            result['descriptive'] = {}
            return result
        n = len(col_numeric)
        # --- Descriptive statistics + Tukey-fence outliers ---
        q1 = float(col_numeric.quantile(0.25))
        q3 = float(col_numeric.quantile(0.75))
        iqr_val = q3 - q1
        lower_bound = q1 - 1.5 * iqr_val
        upper_bound = q3 + 1.5 * iqr_val
        outliers = col_numeric[(col_numeric < lower_bound) | (col_numeric > upper_bound)]
        result['descriptive'] = {
            'mean': round(float(col_numeric.mean()), 4),
            # Sample std is NaN for n < 2 (not JSON-safe); emit None, matching
            # the guard used for skewness/kurtosis below.
            'std': round(float(col_numeric.std()), 4) if n >= 2 else None,
            'median': round(float(col_numeric.median()), 4),
            'min': round(float(col_numeric.min()), 4),
            'max': round(float(col_numeric.max()), 4),
            'q1': round(q1, 4),
            'q3': round(q3, 4),
            'iqr': round(iqr_val, 4),
            'skewness': round(float(col_numeric.skew()), 4) if n >= 3 else None,
            'kurtosis': round(float(col_numeric.kurtosis()), 4) if n >= 4 else None,
        }
        result['outliers'] = {
            'count': int(len(outliers)),
            'rate': round(len(outliers) / n * 100, 2),
            'lowerBound': round(lower_bound, 4),
            'upperBound': round(upper_bound, 4),
        }
        # --- Histogram (bin count capped at max_bins, minimum 1) ---
        n_bins = min(max_bins, unique_count)
        hist_counts, hist_edges = np.histogram(col_numeric, bins=max(n_bins, 1))
        result['histogram'] = {
            'counts': [int(c) for c in hist_counts],
            'edges': [round(float(e), 4) for e in hist_edges],
        }
        # --- Normality test (Shapiro <= 5000, else K-S with sample params) ---
        if n >= 3:
            try:
                if n <= 5000:
                    stat, p_val = scipy_stats.shapiro(col_numeric)
                    method = 'shapiro_wilk'
                else:
                    stat, p_val = scipy_stats.kstest(col_numeric, 'norm',
                                                     args=(col_numeric.mean(), col_numeric.std()))
                    method = 'kolmogorov_smirnov'
                result['normalityTest'] = {
                    'method': method,
                    'statistic': round(float(stat), 4),
                    'pValue': round(float(p_val), 4),
                    'isNormal': bool(p_val >= 0.05),
                }
            except Exception:
                # e.g. a constant column can make the test fail; report as unavailable
                result['normalityTest'] = None
        # --- Q-Q plot points (sorted data, evenly subsampled to max_qq_points) ---
        sorted_data = np.sort(col_numeric.values)
        if n > max_qq_points:
            indices = np.linspace(0, n - 1, max_qq_points, dtype=int)
            sampled = sorted_data[indices]
        else:
            sampled = sorted_data
        m = len(sampled)
        # Theoretical normal quantiles at plotting positions i/(m+1).
        theoretical = scipy_stats.norm.ppf(
            np.linspace(1 / (m + 1), m / (m + 1), m)
        )
        result['qqPlot'] = {
            'theoretical': [round(float(t), 4) for t in theoretical],
            'observed': [round(float(o), 4) for o in sampled],
        }
    elif col_type == 'categorical':
        value_counts = non_null.value_counts()
        total_non_null = len(non_null)
        result['distribution'] = [
            {
                'value': str(val),
                'count': int(cnt),
                'percentage': round(cnt / total_non_null * 100, 2)
            }
            for val, cnt in value_counts.items()
        ]
        result['descriptive'] = {
            'totalLevels': int(len(value_counts)),
            'modeValue': str(value_counts.index[0]) if len(value_counts) > 0 else None,
            'modeCount': int(value_counts.iloc[0]) if len(value_counts) > 0 else 0,
        }
    return result