"""
SSA DataProfile - data profile generation module (Phase 2A -> Phase I).

Builds a quick profile of a dataset at upload time; the profile is used by
the LLM to generate the SAP (statistical analysis plan). The implementation
relies on vectorised pandas operations.

Added in Phase I:
- compute_normality_tests(df, columns) - Shapiro-Wilk / K-S normality tests
- compute_complete_cases(df) - number of complete cases
- analyze_variable_detail() - single-variable detail (histogram + Q-Q data)
"""

import math
import re
from collections import Counter
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats as scipy_stats

try:
    from loguru import logger
except ImportError:  # keep the module importable where loguru is not installed
    import logging

    logger = logging.getLogger(__name__)


# Above this sample size the code switches from Shapiro-Wilk to K-S.
_SHAPIRO_MAX_N = 5000
# Fewer observations than this are not tested for normality at all.
_MIN_NORMALITY_N = 3
# p-values at or above this threshold are reported as "normal".
_NORMALITY_ALPHA = 0.05

# Column names that look like identifiers / bookkeeping dates.
# `^id(_|$)` matches a column named exactly "id" as well as "id_xxx"
# (matching is case-insensitive).
_ID_PATTERNS = re.compile(
    r'(_id|_no|_code|编号|序号|流水号|主键|record_date|visit_date|enroll_date)$'
    r'|^id(_|$)'
    r'|^(patient|subject|sample|record)_?id$',
    re.IGNORECASE
)


# ────────────────────────────────────────────
# Private helpers
# ────────────────────────────────────────────

def _round_finite(value: Any, ndigits: int = 4) -> Optional[float]:
    """Round *value* to *ndigits*; return None for NaN / +-inf.

    Non-finite floats are not valid JSON, and the profile is a JSON payload,
    so undefined statistics (e.g. the std of one observation) become None.
    """
    number = float(value)
    if not math.isfinite(number):
        return None
    return round(number, ndigits)


def _tukey_fences(values: pd.Series) -> Tuple[float, float, float, float, float]:
    """Return (q1, q3, iqr, lower_bound, upper_bound) using the 1.5 * IQR rule."""
    q1 = float(values.quantile(0.25))
    q3 = float(values.quantile(0.75))
    iqr = q3 - q1
    return q1, q3, iqr, q1 - 1.5 * iqr, q3 + 1.5 * iqr


def _location_spread(values: pd.Series, q1: float, q3: float, iqr: float) -> Dict[str, Any]:
    """Return mean / std / median / min / max / quartiles, rounded to 4 digits."""
    return {
        'mean': _round_finite(values.mean()),
        'std': _round_finite(values.std()),  # pandas default (ddof=1); NaN for n == 1 -> None
        'median': _round_finite(values.median()),
        'min': _round_finite(values.min()),
        'max': _round_finite(values.max()),
        'q1': _round_finite(q1),
        'q3': _round_finite(q3),
        'iqr': _round_finite(iqr),
    }


def _shape_stats(values: pd.Series) -> Dict[str, Any]:
    """Return skewness (needs n >= 3) and kurtosis (needs n >= 4), else None."""
    n = len(values)
    return {
        'skewness': _round_finite(values.skew()) if n >= 3 else None,
        'kurtosis': _round_finite(values.kurtosis()) if n >= 4 else None,
    }


def _normality_test(values: pd.Series) -> Dict[str, Any]:
    """Run Shapiro-Wilk (n <= 5000) or Kolmogorov-Smirnov (n > 5000) on *values*.

    The K-S test is run against a normal distribution whose mean and std are
    taken from the sample itself. Exceptions from scipy propagate to the
    caller, which decides how to report them.
    """
    if len(values) <= _SHAPIRO_MAX_N:
        stat, p_value = scipy_stats.shapiro(values)
        method = 'shapiro_wilk'
    else:
        stat, p_value = scipy_stats.kstest(values, 'norm', args=(values.mean(), values.std()))
        method = 'kolmogorov_smirnov'
    return {
        'method': method,
        'statistic': _round_finite(stat),
        'pValue': _round_finite(p_value),
        'isNormal': bool(p_value >= _NORMALITY_ALPHA),
    }


def _detect_id_like(col_name: str, col_type: str, unique_count: int, total_count: int) -> bool:
    """
    Decide whether a column is a non-analysis variable (ID / high-cardinality
    string / date).

    Columns flagged True are meant to be removed by the Q-layer context
    pruning before the profile is injected into the prompt.
    """
    if col_type == 'datetime':
        return True
    if _ID_PATTERNS.search(col_name):
        return True
    # Text or categorical columns where (almost) every row is distinct.
    if col_type in ('text', 'categorical') and total_count > 0:
        return unique_count / total_count > 0.95
    return False


def _numeric_detail(non_null: pd.Series, unique_count: int,
                    max_bins: int, max_qq_points: int) -> Dict[str, Any]:
    """Build the numeric part of analyze_variable_detail's result."""
    values = pd.to_numeric(non_null, errors='coerce').dropna()
    if len(values) == 0:
        return {'descriptive': {}}

    q1, q3, iqr, lower_bound, upper_bound = _tukey_fences(values)
    outlier_count = int(((values < lower_bound) | (values > upper_bound)).sum())

    detail: Dict[str, Any] = {
        'descriptive': {**_location_spread(values, q1, q3, iqr), **_shape_stats(values)},
        'outliers': {
            'count': outlier_count,
            'rate': round(outlier_count / len(values) * 100, 2),
            'lowerBound': _round_finite(lower_bound),
            'upperBound': _round_finite(upper_bound),
        },
    }

    data = values.to_numpy(dtype=float)

    # Never more bins than max_bins or than distinct values; at least one bin.
    n_bins = max(min(max_bins, unique_count), 1)
    hist_counts, hist_edges = np.histogram(data, bins=n_bins)
    detail['histogram'] = {
        'counts': [int(c) for c in hist_counts],
        'edges': [round(float(e), 4) for e in hist_edges],
    }

    if len(values) >= _MIN_NORMALITY_N:
        try:
            detail['normalityTest'] = _normality_test(values)
        except Exception:
            detail['normalityTest'] = None

    # Q-Q plot: pair each selected order statistic with the normal quantile
    # at ITS OWN plotting position (rank / (n + 1)). Using positions of a
    # sample of size max_qq_points instead would pull the tails inwards.
    sorted_data = np.sort(data)
    n = len(sorted_data)
    if n > max_qq_points:
        picks = np.linspace(0, n - 1, max_qq_points, dtype=int)
    else:
        picks = np.arange(n)
    theoretical = scipy_stats.norm.ppf((picks + 1) / (n + 1))
    detail['qqPlot'] = {
        'theoretical': [round(float(t), 4) for t in theoretical],
        'observed': [round(float(o), 4) for o in sorted_data[picks]],
    }
    return detail


def _categorical_detail(non_null: pd.Series) -> Dict[str, Any]:
    """Build the categorical part of analyze_variable_detail's result."""
    value_counts = non_null.value_counts()
    total_non_null = len(non_null)
    has_levels = len(value_counts) > 0
    return {
        'distribution': [
            {
                'value': str(val),
                'count': int(cnt),
                'percentage': round(int(cnt) / total_non_null * 100, 2)
            }
            for val, cnt in value_counts.items()
        ],
        'descriptive': {
            'totalLevels': int(len(value_counts)),
            'modeValue': str(value_counts.index[0]) if has_levels else None,
            'modeCount': int(value_counts.iloc[0]) if has_levels else 0,
        },
    }


# ────────────────────────────────────────────
# Profile generation
# ────────────────────────────────────────────

def generate_data_profile(df: pd.DataFrame, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Generate the data profile (DataProfile) of a data frame.

    Args:
        df: input data frame
        max_unique_values: maximum number of distinct values listed for a
            categorical variable

    Returns:
        DataProfile JSON structure with keys 'columns', 'summary',
        'normalityTests' and 'completeCaseCount'.
    """
    n_rows, n_cols = int(df.shape[0]), int(df.shape[1])
    logger.info(f"开始生成数据画像: {df.shape[0]} 行, {df.shape[1]} 列")

    columns = [analyze_column(df[name], name, max_unique_values) for name in df.columns]

    type_counts = Counter(profile['type'] for profile in columns)
    numeric_count = type_counts['numeric']
    categorical_count = type_counts['categorical']
    datetime_count = type_counts['datetime']

    total_cells = n_rows * n_cols
    total_missing = int(df.isna().sum().sum())

    summary = {
        'totalRows': n_rows,
        'totalColumns': n_cols,
        'numericColumns': numeric_count,
        'categoricalColumns': categorical_count,
        'datetimeColumns': datetime_count,
        # Everything that is not numeric / categorical / datetime counts as text.
        'textColumns': n_cols - numeric_count - categorical_count - datetime_count,
        'overallMissingRate': round(total_missing / total_cells * 100, 2) if total_cells > 0 else 0,
        'totalMissingCells': total_missing
    }

    normality_tests = compute_normality_tests(df, columns)
    complete_case_count = compute_complete_cases(df)

    logger.info(f"数据画像生成完成: {numeric_count} 数值列, {categorical_count} 分类列, 完整病例 {complete_case_count}")

    return {
        'columns': columns,
        'summary': summary,
        'normalityTests': normality_tests,
        'completeCaseCount': complete_case_count
    }


def analyze_column(col: pd.Series, col_name: str, max_unique_values: int = 20) -> Dict[str, Any]:
    """
    Analyse the statistical characteristics of a single column.

    Args:
        col: column data
        col_name: column name
        max_unique_values: maximum number of distinct values to list

    Returns:
        Column profile: name, inferred type, missing / unique / total counts,
        type-specific statistics and the 'isIdLike' flag.
    """
    non_null = col.dropna()
    total_count = len(col)
    missing_count = int(col.isna().sum())
    unique_count = int(non_null.nunique())

    col_type = infer_column_type(col, unique_count, total_count)

    profile: Dict[str, Any] = {
        'name': col_name,
        'type': col_type,
        'missingCount': missing_count,
        'missingRate': round(missing_count / total_count * 100, 2) if total_count > 0 else 0,
        'uniqueCount': unique_count,
        'totalCount': total_count
    }

    if col_type == 'numeric':
        profile.update(analyze_numeric_column(non_null))
    elif col_type == 'categorical':
        profile.update(analyze_categorical_column(non_null, max_unique_values))
    elif col_type == 'datetime':
        profile.update(analyze_datetime_column(non_null))

    profile['isIdLike'] = _detect_id_like(col_name, col_type, unique_count, total_count)
    return profile


def infer_column_type(col: pd.Series, unique_count: int, total_count: int) -> str:
    """
    Infer the analysis type of a column.

    Args:
        col: column data (including missing values)
        unique_count: number of distinct non-null values
        total_count: number of rows, missing values included

    Returns:
        'numeric' | 'categorical' | 'datetime' | 'text'
    """
    unique_ratio = unique_count / total_count if total_count > 0 else 0

    if pd.api.types.is_datetime64_any_dtype(col):
        return 'datetime'

    # Booleans are two-level factors; quantile / skewness summaries are not
    # meaningful for them, so they are never profiled as numeric.
    if pd.api.types.is_bool_dtype(col):
        return 'categorical'

    if pd.api.types.is_numeric_dtype(col):
        # Few distinct values relative to the row count: treat as a coded factor.
        if unique_count <= 10 and unique_ratio < 0.05:
            return 'categorical'
        return 'numeric'

    if col.dtype == 'object' or col.dtype == 'string':
        non_null = col.dropna()
        if len(non_null) == 0:
            return 'text'

        if unique_count <= 30 and unique_ratio < 0.1:
            return 'categorical'

        # Best-effort probes: a failed conversion just means "not this type".
        try:
            pd.to_numeric(non_null, errors='raise')
            return 'numeric'
        except Exception:
            pass

        try:
            pd.to_datetime(non_null, errors='raise')
            return 'datetime'
        except Exception:
            pass

        return 'text'

    return 'text'


def analyze_numeric_column(col: pd.Series) -> Dict[str, Any]:
    """
    Summarise a numeric column.

    Args:
        col: column data with missing values already removed; values that
            cannot be converted to numbers are dropped.

    Returns:
        Descriptive statistics, 1.5 * IQR outlier count / rate, skewness and
        kurtosis; an empty dict when no numeric value is left. Statistics
        that are undefined for the sample are reported as None.
    """
    if len(col) == 0:
        return {}

    col_numeric = pd.to_numeric(col, errors='coerce').dropna()
    if len(col_numeric) == 0:
        return {}

    q1, q3, iqr, lower_bound, upper_bound = _tukey_fences(col_numeric)
    outlier_count = int(((col_numeric < lower_bound) | (col_numeric > upper_bound)).sum())

    stats = _location_spread(col_numeric, q1, q3, iqr)
    stats['outlierCount'] = outlier_count
    stats['outlierRate'] = round(outlier_count / len(col_numeric) * 100, 2)
    stats.update(_shape_stats(col_numeric))
    return stats


def analyze_categorical_column(col: pd.Series, max_values: int = 20) -> Dict[str, Any]:
    """
    Summarise a categorical column.

    Args:
        col: column data with missing values already removed
        max_values: maximum number of levels listed in 'topValues'

    Returns:
        The most frequent levels (value / count / percentage), the total
        number of levels and the mode; an empty dict for an empty column.
    """
    total = len(col)
    if total == 0:
        return {}

    value_counts = col.value_counts()
    top_values = [
        {
            'value': str(value),
            'count': int(count),
            'percentage': round(int(count) / total * 100, 2)
        }
        for value, count in value_counts.head(max_values).items()
    ]
    has_levels = len(value_counts) > 0

    return {
        'topValues': top_values,
        'totalLevels': int(len(value_counts)),
        'modeValue': str(value_counts.index[0]) if has_levels else None,
        'modeCount': int(value_counts.iloc[0]) if has_levels else 0
    }


def analyze_datetime_column(col: pd.Series) -> Dict[str, Any]:
    """
    Summarise a datetime column.

    Args:
        col: column data with missing values already removed

    Returns:
        'minDate' / 'maxDate' (ISO format) and 'dateRange'; an empty dict
        when the column is empty or cannot be handled as dates.
    """
    if len(col) == 0:
        return {}

    # Deliberately best-effort: a column that cannot be summarised as dates
    # yields an empty summary instead of failing the whole profile.
    try:
        col_dt = pd.to_datetime(col, errors='coerce').dropna()
        if len(col_dt) == 0:
            return {}

        earliest, latest = col_dt.min(), col_dt.max()
        return {
            'minDate': earliest.isoformat(),
            'maxDate': latest.isoformat(),
            'dateRange': str(latest - earliest)
        }
    except Exception:
        return {}


def get_quality_score(profile: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compute a data quality score from a data profile.

    Starts at 100 and subtracts penalties for the overall missing rate,
    for every column whose outlier rate exceeds 10 %, and for small samples.

    Args:
        profile: structure returned by generate_data_profile

    Returns:
        Score (0-100), letter grade, grade description, issues and
        recommendations.
    """
    summary = profile.get('summary', {})
    columns = profile.get('columns', [])

    score = 100.0
    issues = []
    recommendations = []

    # Penalty for missing data.
    overall_missing_rate = summary.get('overallMissingRate', 0)
    if overall_missing_rate > 20:
        score -= 30
        issues.append(f"整体缺失率较高 ({overall_missing_rate}%)")
        recommendations.append("建议检查数据完整性,考虑缺失值处理")
    elif overall_missing_rate > 10:
        score -= 15
        issues.append(f"整体缺失率中等 ({overall_missing_rate}%)")
        recommendations.append("建议在分析前处理缺失值")
    elif overall_missing_rate > 5:
        score -= 5
        issues.append(f"存在少量缺失 ({overall_missing_rate}%)")

    # Penalty per column with many outliers.
    for col in columns:
        if col.get('outlierRate', 0) > 10:
            score -= 5
            issues.append(f"列 '{col['name']}' 存在较多异常值 ({col['outlierRate']}%)")
            recommendations.append(f"建议检查列 '{col['name']}' 的异常值")

    # Penalty for small samples.
    total_rows = summary.get('totalRows', 0)
    if total_rows < 30:
        score -= 20
        issues.append(f"样本量较小 (n={total_rows})")
        recommendations.append("小样本可能影响统计检验的效力")
    elif total_rows < 100:
        score -= 10
        issues.append(f"样本量中等 (n={total_rows})")

    score = max(0, min(100, score))

    if score >= 80:
        grade, grade_desc = 'A', '数据质量良好'
    elif score >= 60:
        grade, grade_desc = 'B', '数据质量中等'
    elif score >= 40:
        grade, grade_desc = 'C', '数据质量较差'
    else:
        grade, grade_desc = 'D', '数据质量很差'

    return {
        'score': round(score, 1),
        'grade': grade,
        'gradeDescription': grade_desc,
        'issues': issues,
        'recommendations': recommendations
    }


# ────────────────────────────────────────────
# Functions added in Phase I
# ────────────────────────────────────────────

def compute_normality_tests(df: pd.DataFrame, columns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Run a normality test on every numeric column.

    Uses Shapiro-Wilk for n <= 5000 and falls back to Kolmogorov-Smirnov
    for n > 5000. Columns with fewer than 3 numeric values are skipped; a
    column whose test fails is logged and skipped.

    Args:
        df: input data frame
        columns: column profiles (as produced by analyze_column); only
            entries with type 'numeric' are tested

    Returns:
        One dict per tested column: variable, method, statistic, pValue,
        isNormal.
    """
    results = []
    numeric_cols = [c['name'] for c in columns if c['type'] == 'numeric']

    for col_name in numeric_cols:
        try:
            col_data = pd.to_numeric(df[col_name], errors='coerce').dropna()
            if len(col_data) < _MIN_NORMALITY_N:
                continue
            results.append({'variable': col_name, **_normality_test(col_data)})
        except Exception as e:
            logger.warning(f"正态性检验失败 [{col_name}]: {e}")

    return results


def compute_complete_cases(df: pd.DataFrame) -> int:
    """Return the number of rows that have no missing value in any column."""
    return int(df.dropna().shape[0])


def analyze_variable_detail(df: pd.DataFrame, variable_name: str,
                            max_bins: int = 30, max_qq_points: int = 200) -> Dict[str, Any]:
    """
    Detailed single-variable analysis (Phase I: backend of the
    get_variable_detail tool).

    Args:
        df: input data frame
        variable_name: name of the column to analyse
        max_bins: hard upper limit on the number of histogram bins
        max_qq_points: upper limit on the number of Q-Q plot points

    Returns:
        On success: basic counts plus, for numeric variables, descriptive
        statistics, outlier info, histogram data, a normality test and Q-Q
        plot points; for categorical variables, the level distribution and
        mode. If the variable does not exist: {'success': False, 'error': ...}.
    """
    if variable_name not in df.columns:
        return {'success': False, 'error': f"变量 '{variable_name}' 不存在"}

    col = df[variable_name]
    non_null = col.dropna()
    total = len(col)
    missing = int(col.isna().sum())
    unique_count = int(non_null.nunique())
    col_type = infer_column_type(col, unique_count, total)

    result: Dict[str, Any] = {
        'success': True,
        'variable': variable_name,
        'type': col_type,
        'totalCount': total,
        'missingCount': missing,
        'missingRate': round(missing / total * 100, 2) if total > 0 else 0,
        'uniqueCount': unique_count,
    }

    if col_type == 'numeric':
        result.update(_numeric_detail(non_null, unique_count, max_bins, max_qq_points))
    elif col_type == 'categorical':
        result.update(_categorical_detail(non_null))

    return result