Phase I - Session Blackboard + READ Layer: - SessionBlackboardService with Postgres-Only cache - DataProfileService for data overview generation - PicoInferenceService for LLM-driven PICO extraction - Frontend DataContextCard and VariableDictionaryPanel - E2E tests: 31/31 passed Phase II - Conversation Layer LLM + Intent Router: - ConversationService with SSE streaming - IntentRouterService (rule-first + LLM fallback, 6 intents) - SystemPromptService with 6-segment dynamic assembly - TokenTruncationService for context management - ChatHandlerService as unified chat entry - Frontend SSAChatPane and useSSAChat hook - E2E tests: 38/38 passed Phase III - Method Consultation + AskUser Standardization: - ToolRegistryService with Repository Pattern - MethodConsultService with DecisionTable + LLM enhancement - AskUserService with global interrupt handling - Frontend AskUserCard component - E2E tests: 13/13 passed Phase IV - Dialogue-Driven Analysis + QPER Integration: - ToolOrchestratorService (plan/execute/report) - analysis_plan SSE event for WorkflowPlan transmission - Dual-channel confirmation (ask_user card + workspace button) - PICO as optional hint for LLM parsing - E2E tests: 25/25 passed R Statistics Service: - 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon - Enhanced guardrails and block helpers - Comprehensive test suite (run_all_tools_test.js) Documentation: - Updated system status document (v5.9) - Updated SSA module status and development plan (v1.8) Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25) Co-authored-by: Cursor <cursoragent@cursor.com>
496 lines
17 KiB
Python
496 lines
17 KiB
Python
"""
|
||
SSA DataProfile - 数据画像生成模块 (Phase 2A → Phase I)
|
||
|
||
提供数据上传时的快速画像生成,用于 LLM 生成 SAP(分析计划)。
|
||
高性能实现,利用 pandas 的向量化操作。
|
||
|
||
Phase I 新增:
|
||
- compute_normality_tests(df) — Shapiro-Wilk / K-S 正态性检验
|
||
- compute_complete_cases(df) — 完整病例计数
|
||
- analyze_variable_detail() — 单变量详细分析(直方图+Q-Q图数据)
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from scipy import stats as scipy_stats
|
||
from typing import List, Dict, Any, Optional
|
||
from loguru import logger
|
||
|
||
|
||
def generate_data_profile(df: pd.DataFrame, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Build the DataProfile for an uploaded dataset.

    Args:
        df: input dataframe
        max_unique_values: cap on the number of distinct values reported
            for each categorical column

    Returns:
        DataProfile JSON structure: per-column profiles, dataset summary,
        normality test results, and the complete-case count.
    """
    logger.info(f"开始生成数据画像: {df.shape[0]} 行, {df.shape[1]} 列")

    n_rows, n_cols = df.shape
    type_tally = {'numeric': 0, 'categorical': 0, 'datetime': 0}

    columns = []
    for name in df.columns:
        col_profile = analyze_column(df[name], name, max_unique_values)
        columns.append(col_profile)
        if col_profile['type'] in type_tally:
            type_tally[col_profile['type']] += 1

    cell_total = n_rows * n_cols
    missing_total = df.isna().sum().sum()
    missing_rate = round(missing_total / cell_total * 100, 2) if cell_total > 0 else 0

    classified = sum(type_tally.values())
    summary = {
        'totalRows': int(n_rows),
        'totalColumns': int(n_cols),
        'numericColumns': type_tally['numeric'],
        'categoricalColumns': type_tally['categorical'],
        'datetimeColumns': type_tally['datetime'],
        # anything not numeric/categorical/datetime is counted as text
        'textColumns': int(n_cols) - classified,
        'overallMissingRate': missing_rate,
        'totalMissingCells': int(missing_total),
    }

    normality_tests = compute_normality_tests(df, columns)
    complete_case_count = compute_complete_cases(df)

    logger.info(
        f"数据画像生成完成: {type_tally['numeric']} 数值列, "
        f"{type_tally['categorical']} 分类列, 完整病例 {complete_case_count}"
    )

    return {
        'columns': columns,
        'summary': summary,
        'normalityTests': normality_tests,
        'completeCaseCount': complete_case_count,
    }
|
||
|
||
|
||
def analyze_column(col: pd.Series, col_name: str, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Profile a single column: inferred type, missingness, cardinality, and
    type-specific descriptive statistics.

    Args:
        col: column data
        col_name: column name
        max_unique_values: cap on distinct values reported for categoricals

    Returns:
        per-column profile dict
    """
    total_count = len(col)
    non_null = col.dropna()
    missing_count = int(col.isna().sum())
    missing_rate = round(missing_count / total_count * 100, 2) if total_count > 0 else 0
    unique_count = int(non_null.nunique())

    col_type = infer_column_type(col, unique_count, total_count)

    profile: Dict[str, Any] = {
        'name': col_name,
        'type': col_type,
        'missingCount': missing_count,
        'missingRate': missing_rate,
        'uniqueCount': unique_count,
        'totalCount': total_count,
    }

    # Type-specific descriptive statistics, dispatched by inferred type.
    stats_for_type = {
        'numeric': lambda: analyze_numeric_column(non_null),
        'categorical': lambda: analyze_categorical_column(non_null, max_unique_values),
        'datetime': lambda: analyze_datetime_column(non_null),
    }.get(col_type)
    if stats_for_type is not None:
        profile.update(stats_for_type())

    # Non-analysis flag consumed by the Q-layer context pruning step.
    profile['isIdLike'] = _detect_id_like(col_name, col_type, unique_count, total_count)

    return profile
|
||
|
||
|
||
import re
|
||
|
||
_ID_PATTERNS = re.compile(
|
||
r'(_id|_no|_code|编号|序号|流水号|主键|record_date|visit_date|enroll_date)$|^(id|ID|Id)_|^(patient|subject|sample|record)_?id$',
|
||
re.IGNORECASE
|
||
)
|
||
|
||
|
||
def _detect_id_like(col_name: str, col_type: str, unique_count: int, total_count: int) -> bool:
|
||
"""
|
||
判断列是否为非分析变量(ID / 高基数字符串 / 日期)
|
||
标记为 True 后,Q 层 Context Pruning 会在注入 Prompt 前物理剔除这些列
|
||
"""
|
||
if col_type == 'datetime':
|
||
return True
|
||
if _ID_PATTERNS.search(col_name):
|
||
return True
|
||
if col_type == 'text' and total_count > 0 and unique_count / total_count > 0.95:
|
||
return True
|
||
if col_type == 'categorical' and total_count > 0 and unique_count / total_count > 0.95:
|
||
return True
|
||
return False
|
||
|
||
|
||
def infer_column_type(col: pd.Series, unique_count: int, total_count: int) -> str:
    """
    Infer the logical type of a column.

    Numeric dtypes with very low cardinality (<= 10 distinct values and
    < 5% unique ratio) are treated as categorical codes. Object/string
    columns are classified as categorical when low-cardinality, otherwise
    tried as numeric, then datetime, and fall back to text.

    Args:
        col: column data (may contain NaN)
        unique_count: number of distinct non-null values
        total_count: total number of rows (including NaN)

    Returns:
        'numeric' | 'categorical' | 'datetime' | 'text'
    """
    if pd.api.types.is_datetime64_any_dtype(col):
        return 'datetime'

    if pd.api.types.is_numeric_dtype(col):
        unique_ratio = unique_count / total_count if total_count > 0 else 0
        if unique_count <= 10 and unique_ratio < 0.05:
            return 'categorical'
        return 'numeric'

    if col.dtype == 'object' or col.dtype == 'string':
        non_null = col.dropna()
        if len(non_null) == 0:
            return 'text'

        unique_ratio = unique_count / total_count if total_count > 0 else 0
        if unique_count <= 30 and unique_ratio < 0.1:
            return 'categorical'

        # Probe for numeric content. Narrowed from a bare `except:` so that
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            pd.to_numeric(non_null, errors='raise')
            return 'numeric'
        except (ValueError, TypeError):
            pass

        # Probe for datetime content (ParserError/OutOfBoundsDatetime are
        # ValueError subclasses; OverflowError can occur on extreme values).
        try:
            pd.to_datetime(non_null, errors='raise')
            return 'datetime'
        except (ValueError, TypeError, OverflowError):
            pass

        return 'text'

    return 'text'
|
||
|
||
|
||
def analyze_numeric_column(col: pd.Series) -> Dict[str, Any]:
    """
    Descriptive statistics for a numeric column: central tendency, spread,
    quartiles, and outlier counts using 1.5 * IQR fences.

    Returns an empty dict when the column has no coercible numeric values.
    """
    if len(col) == 0:
        return {}

    values = pd.to_numeric(col, errors='coerce').dropna()
    n = len(values)
    if n == 0:
        return {}

    q1 = float(values.quantile(0.25))
    q3 = float(values.quantile(0.75))
    iqr = q3 - q1
    fence_low = q1 - 1.5 * iqr
    fence_high = q3 + 1.5 * iqr
    outside_fences = (values < fence_low) | (values > fence_high)
    outlier_count = int(outside_fences.sum())

    return {
        'mean': round(float(values.mean()), 4),
        'std': round(float(values.std()), 4),
        'median': round(float(values.median()), 4),
        'min': round(float(values.min()), 4),
        'max': round(float(values.max()), 4),
        'q1': round(q1, 4),
        'q3': round(q3, 4),
        'iqr': round(iqr, 4),
        'outlierCount': outlier_count,
        'outlierRate': round(outlier_count / n * 100, 2) if n > 0 else 0,
        # skewness/kurtosis are undefined below 3/4 observations respectively
        'skewness': round(float(values.skew()), 4) if n >= 3 else None,
        'kurtosis': round(float(values.kurtosis()), 4) if n >= 4 else None,
    }
|
||
|
||
|
||
def analyze_categorical_column(col: pd.Series, max_values: int = 20) -> Dict[str, Any]:
    """
    Frequency statistics for a categorical column: the top levels
    (capped at max_values), total level count, and the mode.
    """
    if len(col) == 0:
        return {}

    counts = col.value_counts()
    total = len(col)

    top_values = [
        {
            'value': str(level),
            'count': int(freq),
            'percentage': round(freq / total * 100, 2),
        }
        for level, freq in counts.head(max_values).items()
    ]

    has_levels = len(counts) > 0
    return {
        'topValues': top_values,
        'totalLevels': int(len(counts)),
        'modeValue': str(counts.index[0]) if has_levels else None,
        'modeCount': int(counts.iloc[0]) if has_levels else 0,
    }
|
||
|
||
|
||
def analyze_datetime_column(col: pd.Series) -> Dict[str, Any]:
    """
    Date-range statistics for a datetime column.

    Returns:
        dict with minDate/maxDate (ISO-8601 strings) and dateRange
        (string form of the spanned Timedelta), or an empty dict when
        the column is empty or cannot be parsed as datetimes.
    """
    if len(col) == 0:
        return {}

    try:
        col_dt = pd.to_datetime(col, errors='coerce').dropna()
        if len(col_dt) == 0:
            return {}

        return {
            'minDate': col_dt.min().isoformat(),
            'maxDate': col_dt.max().isoformat(),
            'dateRange': str(col_dt.max() - col_dt.min())
        }
    except Exception:
        # Best-effort: keep the original contract of returning {} on any
        # parse failure, but avoid the original bare `except:` which also
        # swallowed KeyboardInterrupt/SystemExit.
        return {}
||
|
||
|
||
def get_quality_score(profile: Dict[str, Any]) -> Dict[str, Any]:
    """
    Score overall data quality on a 0-100 scale.

    Deductions: overall missing rate (up to -30), per-column outlier
    burden (-5 per flagged column), and small sample size (up to -20).
    The final score maps to grades A (>=80), B (>=60), C (>=40), D.

    Returns:
        dict with score, grade, gradeDescription, issues, recommendations
    """
    summary = profile.get('summary', {})
    columns = profile.get('columns', [])

    score = 100.0
    issues: List[str] = []
    recommendations: List[str] = []

    # Missing-data penalty, tiered by overall missing rate.
    missing_rate = summary.get('overallMissingRate', 0)
    if missing_rate > 20:
        score -= 30
        issues.append(f"整体缺失率较高 ({missing_rate}%)")
        recommendations.append("建议检查数据完整性,考虑缺失值处理")
    elif missing_rate > 10:
        score -= 15
        issues.append(f"整体缺失率中等 ({missing_rate}%)")
        recommendations.append("建议在分析前处理缺失值")
    elif missing_rate > 5:
        score -= 5
        issues.append(f"存在少量缺失 ({missing_rate}%)")

    # Outlier penalty per column whose IQR-outlier rate exceeds 10%.
    for col in columns:
        if col.get('outlierRate', 0) > 10:
            score -= 5
            issues.append(f"列 '{col['name']}' 存在较多异常值 ({col['outlierRate']}%)")
            recommendations.append(f"建议检查列 '{col['name']}' 的异常值")

    # Sample-size penalty.
    total_rows = summary.get('totalRows', 0)
    if total_rows < 30:
        score -= 20
        issues.append(f"样本量较小 (n={total_rows})")
        recommendations.append("小样本可能影响统计检验的效力")
    elif total_rows < 100:
        score -= 10
        issues.append(f"样本量中等 (n={total_rows})")

    score = max(0, min(100, score))

    # Map score to a letter grade via a descending threshold table.
    grade, grade_desc = next(
        (letter, desc)
        for threshold, letter, desc in (
            (80, 'A', '数据质量良好'),
            (60, 'B', '数据质量中等'),
            (40, 'C', '数据质量较差'),
            (float('-inf'), 'D', '数据质量很差'),
        )
        if score >= threshold
    )

    return {
        'score': round(score, 1),
        'grade': grade,
        'gradeDescription': grade_desc,
        'issues': issues,
        'recommendations': recommendations,
    }
|
||
|
||
|
||
# ────────────────────────────────────────────
|
||
# Phase I 新增函数
|
||
# ────────────────────────────────────────────
|
||
|
||
def compute_normality_tests(df: pd.DataFrame, columns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Run a normality test on every numeric column.

    Uses Shapiro-Wilk for n <= 5000 and falls back to a Kolmogorov-Smirnov
    test (against a normal fitted by sample mean/std) for larger samples.
    Columns with fewer than 3 valid values are skipped; per-column
    failures are logged and skipped rather than aborting the batch.
    """
    results: List[Dict[str, Any]] = []

    for meta in columns:
        if meta['type'] != 'numeric':
            continue
        name = meta['name']
        try:
            values = pd.to_numeric(df[name], errors='coerce').dropna()
            n = len(values)
            if n < 3:
                continue

            if n <= 5000:
                method = 'shapiro_wilk'
                stat, p_value = scipy_stats.shapiro(values)
            else:
                method = 'kolmogorov_smirnov'
                stat, p_value = scipy_stats.kstest(
                    values, 'norm', args=(values.mean(), values.std())
                )

            results.append({
                'variable': name,
                'method': method,
                'statistic': round(float(stat), 4),
                'pValue': round(float(p_value), 4),
                'isNormal': bool(p_value >= 0.05),
            })
        except Exception as e:
            logger.warning(f"正态性检验失败 [{name}]: {e}")

    return results
|
||
|
||
|
||
def compute_complete_cases(df: pd.DataFrame) -> int:
    """Count rows that have no missing value in any column (complete cases)."""
    complete_rows = df.dropna()
    return int(len(complete_rows))
|
||
|
||
|
||
def analyze_variable_detail(df: pd.DataFrame, variable_name: str,
                            max_bins: int = 30, max_qq_points: int = 200) -> Dict[str, Any]:
    """
    Single-variable detail analysis (Phase I: backend for the
    get_variable_detail tool).

    Returns descriptive statistics plus plot-ready distribution data:
    for numeric variables a histogram, IQR outlier summary, normality
    test, and Q-Q plot points; for categorical variables a full level
    frequency distribution.

    Args:
        df: source dataframe
        variable_name: column to analyze; an unknown name yields
            {'success': False, 'error': ...}
        max_bins: hard cap on histogram bins (H2 guardrail)
        max_qq_points: cap on Q-Q plot points (evenly subsampled when
            the sample is larger)

    Returns:
        dict with 'success', basic counts, and type-specific sections.
    """
    if variable_name not in df.columns:
        return {'success': False, 'error': f"变量 '{variable_name}' 不存在"}

    col = df[variable_name]
    non_null = col.dropna()
    total = len(col)
    missing = int(col.isna().sum())
    unique_count = int(non_null.nunique())
    col_type = infer_column_type(col, unique_count, total)

    # Base payload shared by all variable types.
    result: Dict[str, Any] = {
        'success': True,
        'variable': variable_name,
        'type': col_type,
        'totalCount': total,
        'missingCount': missing,
        'missingRate': round(missing / total * 100, 2) if total > 0 else 0,
        'uniqueCount': unique_count,
    }

    if col_type == 'numeric':
        col_numeric = pd.to_numeric(non_null, errors='coerce').dropna()
        if len(col_numeric) == 0:
            # Type was inferred numeric but nothing coerced — return base info only.
            result['descriptive'] = {}
            return result

        # IQR fences (1.5 * IQR) for outlier detection.
        q1 = float(col_numeric.quantile(0.25))
        q3 = float(col_numeric.quantile(0.75))
        iqr_val = q3 - q1
        lower_bound = q1 - 1.5 * iqr_val
        upper_bound = q3 + 1.5 * iqr_val
        outliers = col_numeric[(col_numeric < lower_bound) | (col_numeric > upper_bound)]

        result['descriptive'] = {
            'mean': round(float(col_numeric.mean()), 4),
            'std': round(float(col_numeric.std()), 4),
            'median': round(float(col_numeric.median()), 4),
            'min': round(float(col_numeric.min()), 4),
            'max': round(float(col_numeric.max()), 4),
            'q1': round(q1, 4),
            'q3': round(q3, 4),
            'iqr': round(iqr_val, 4),
            # skewness/kurtosis need at least 3/4 observations respectively
            'skewness': round(float(col_numeric.skew()), 4) if len(col_numeric) >= 3 else None,
            'kurtosis': round(float(col_numeric.kurtosis()), 4) if len(col_numeric) >= 4 else None,
        }

        result['outliers'] = {
            'count': int(len(outliers)),
            'rate': round(len(outliers) / len(col_numeric) * 100, 2),
            'lowerBound': round(lower_bound, 4),
            'upperBound': round(upper_bound, 4),
        }

        # Histogram with bins capped at max_bins (H2 guardrail); max(..., 1)
        # keeps np.histogram valid for degenerate single-value columns.
        n_bins = min(max_bins, unique_count)
        hist_counts, hist_edges = np.histogram(col_numeric, bins=max(n_bins, 1))
        result['histogram'] = {
            'counts': [int(c) for c in hist_counts],
            'edges': [round(float(e), 4) for e in hist_edges],
        }

        # Normality test: Shapiro-Wilk up to n=5000, K-S beyond that.
        if len(col_numeric) >= 3:
            try:
                if len(col_numeric) <= 5000:
                    stat, p_val = scipy_stats.shapiro(col_numeric)
                    method = 'shapiro_wilk'
                else:
                    stat, p_val = scipy_stats.kstest(col_numeric, 'norm',
                                                     args=(col_numeric.mean(), col_numeric.std()))
                    method = 'kolmogorov_smirnov'
                result['normalityTest'] = {
                    'method': method,
                    'statistic': round(float(stat), 4),
                    'pValue': round(float(p_val), 4),
                    'isNormal': bool(p_val >= 0.05),
                }
            except Exception:
                # Test failure is non-fatal; frontend treats None as "unavailable".
                result['normalityTest'] = None

        # Q-Q plot data: sorted observed values (subsampled to max_qq_points)
        # against standard-normal quantiles at plotting positions i/(n+1).
        sorted_data = np.sort(col_numeric.values)
        n = len(sorted_data)
        if n > max_qq_points:
            indices = np.linspace(0, n - 1, max_qq_points, dtype=int)
            sampled = sorted_data[indices]
        else:
            sampled = sorted_data
        theoretical = scipy_stats.norm.ppf(
            np.linspace(1 / (len(sampled) + 1), len(sampled) / (len(sampled) + 1), len(sampled))
        )
        result['qqPlot'] = {
            'theoretical': [round(float(t), 4) for t in theoretical],
            'observed': [round(float(o), 4) for o in sampled],
        }

    elif col_type == 'categorical':
        # Full (uncapped) level distribution, ordered by frequency.
        value_counts = non_null.value_counts()
        total_non_null = len(non_null)
        result['distribution'] = [
            {
                'value': str(val),
                'count': int(cnt),
                'percentage': round(cnt / total_non_null * 100, 2)
            }
            for val, cnt in value_counts.items()
        ]
        result['descriptive'] = {
            'totalLevels': int(len(value_counts)),
            'modeValue': str(value_counts.index[0]) if len(value_counts) > 0 else None,
            'modeCount': int(value_counts.iloc[0]) if len(value_counts) > 0 else 0,
        }

    # datetime/text columns get only the base payload.
    return result
|