""" SSA DataProfile - 数据画像生成模块 (Phase 2A) 提供数据上传时的快速画像生成,用于 LLM 生成 SAP(分析计划)。 高性能实现,利用 pandas 的向量化操作。 """ import pandas as pd import numpy as np from typing import List, Dict, Any, Optional from loguru import logger def generate_data_profile(df: pd.DataFrame, max_unique_values: int = 20) -> Dict[str, Any]: """ 生成数据画像(DataProfile) Args: df: 输入数据框 max_unique_values: 分类变量显示的最大唯一值数量 Returns: DataProfile JSON 结构 """ logger.info(f"开始生成数据画像: {df.shape[0]} 行, {df.shape[1]} 列") columns = [] numeric_count = 0 categorical_count = 0 datetime_count = 0 for col_name in df.columns: col = df[col_name] col_profile = analyze_column(col, col_name, max_unique_values) columns.append(col_profile) if col_profile['type'] == 'numeric': numeric_count += 1 elif col_profile['type'] == 'categorical': categorical_count += 1 elif col_profile['type'] == 'datetime': datetime_count += 1 total_cells = df.shape[0] * df.shape[1] total_missing = df.isna().sum().sum() summary = { 'totalRows': int(df.shape[0]), 'totalColumns': int(df.shape[1]), 'numericColumns': numeric_count, 'categoricalColumns': categorical_count, 'datetimeColumns': datetime_count, 'textColumns': int(df.shape[1]) - numeric_count - categorical_count - datetime_count, 'overallMissingRate': round(total_missing / total_cells * 100, 2) if total_cells > 0 else 0, 'totalMissingCells': int(total_missing) } logger.info(f"数据画像生成完成: {numeric_count} 数值列, {categorical_count} 分类列") return { 'columns': columns, 'summary': summary } def analyze_column(col: pd.Series, col_name: str, max_unique_values: int = 20) -> Dict[str, Any]: """ 分析单个列的统计特征 Args: col: 列数据 col_name: 列名 max_unique_values: 显示的最大唯一值数量 Returns: 列画像 """ non_null = col.dropna() missing_count = int(col.isna().sum()) total_count = len(col) missing_rate = round(missing_count / total_count * 100, 2) if total_count > 0 else 0 unique_count = int(non_null.nunique()) col_type = infer_column_type(col, unique_count, total_count) profile = { 'name': col_name, 'type': col_type, 'missingCount': missing_count, 'missingRate': missing_rate, 'uniqueCount': unique_count, 'totalCount': total_count } if col_type == 'numeric': profile.update(analyze_numeric_column(non_null)) elif col_type == 'categorical': profile.update(analyze_categorical_column(non_null, max_unique_values)) elif col_type == 'datetime': profile.update(analyze_datetime_column(non_null)) profile['isIdLike'] = _detect_id_like(col_name, col_type, unique_count, total_count) return profile import re _ID_PATTERNS = re.compile( r'(_id|_no|_code|编号|序号|流水号|主键|record_date|visit_date|enroll_date)$|^(id|ID|Id)_|^(patient|subject|sample|record)_?id$', re.IGNORECASE ) def _detect_id_like(col_name: str, col_type: str, unique_count: int, total_count: int) -> bool: """ 判断列是否为非分析变量(ID / 高基数字符串 / 日期) 标记为 True 后,Q 层 Context Pruning 会在注入 Prompt 前物理剔除这些列 """ if col_type == 'datetime': return True if _ID_PATTERNS.search(col_name): return True if col_type == 'text' and total_count > 0 and unique_count / total_count > 0.95: return True if col_type == 'categorical' and total_count > 0 and unique_count / total_count > 0.95: return True return False def infer_column_type(col: pd.Series, unique_count: int, total_count: int) -> str: """ 推断列的数据类型 Returns: 'numeric' | 'categorical' | 'datetime' | 'text' """ if pd.api.types.is_datetime64_any_dtype(col): return 'datetime' if pd.api.types.is_numeric_dtype(col): unique_ratio = unique_count / total_count if total_count > 0 else 0 if unique_count <= 10 and unique_ratio < 0.05: return 'categorical' return 'numeric' if col.dtype == 'object' or col.dtype == 'string': non_null = col.dropna() if len(non_null) == 0: return 'text' unique_ratio = unique_count / total_count if total_count > 0 else 0 if unique_count <= 30 and unique_ratio < 0.1: return 'categorical' try: pd.to_numeric(non_null, errors='raise') return 'numeric' except: pass try: pd.to_datetime(non_null, errors='raise') return 'datetime' except: pass return 'text' return 'text' def analyze_numeric_column(col: pd.Series) -> Dict[str, Any]: """ 分析数值列的统计特征 """ if len(col) == 0: return {} col_numeric = pd.to_numeric(col, errors='coerce').dropna() if len(col_numeric) == 0: return {} q1 = float(col_numeric.quantile(0.25)) q3 = float(col_numeric.quantile(0.75)) iqr = q3 - q1 lower_bound = q1 - 1.5 * iqr upper_bound = q3 + 1.5 * iqr outlier_count = int(((col_numeric < lower_bound) | (col_numeric > upper_bound)).sum()) return { 'mean': round(float(col_numeric.mean()), 4), 'std': round(float(col_numeric.std()), 4), 'median': round(float(col_numeric.median()), 4), 'min': round(float(col_numeric.min()), 4), 'max': round(float(col_numeric.max()), 4), 'q1': round(q1, 4), 'q3': round(q3, 4), 'iqr': round(iqr, 4), 'outlierCount': outlier_count, 'outlierRate': round(outlier_count / len(col_numeric) * 100, 2) if len(col_numeric) > 0 else 0, 'skewness': round(float(col_numeric.skew()), 4) if len(col_numeric) >= 3 else None, 'kurtosis': round(float(col_numeric.kurtosis()), 4) if len(col_numeric) >= 4 else None } def analyze_categorical_column(col: pd.Series, max_values: int = 20) -> Dict[str, Any]: """ 分析分类列的统计特征 """ if len(col) == 0: return {} value_counts = col.value_counts() total = len(col) top_values = [] for value, count in value_counts.head(max_values).items(): top_values.append({ 'value': str(value), 'count': int(count), 'percentage': round(count / total * 100, 2) }) return { 'topValues': top_values, 'totalLevels': int(len(value_counts)), 'modeValue': str(value_counts.index[0]) if len(value_counts) > 0 else None, 'modeCount': int(value_counts.iloc[0]) if len(value_counts) > 0 else 0 } def analyze_datetime_column(col: pd.Series) -> Dict[str, Any]: """ 分析日期时间列的统计特征 """ if len(col) == 0: return {} try: col_dt = pd.to_datetime(col, errors='coerce').dropna() if len(col_dt) == 0: return {} return { 'minDate': col_dt.min().isoformat(), 'maxDate': col_dt.max().isoformat(), 'dateRange': str(col_dt.max() - col_dt.min()) } except: return {} def get_quality_score(profile: Dict[str, Any]) -> Dict[str, Any]: """ 计算数据质量评分 Returns: 质量评分和建议 """ summary = profile.get('summary', {}) columns = profile.get('columns', []) score = 100.0 issues = [] recommendations = [] overall_missing_rate = summary.get('overallMissingRate', 0) if overall_missing_rate > 20: score -= 30 issues.append(f"整体缺失率较高 ({overall_missing_rate}%)") recommendations.append("建议检查数据完整性,考虑缺失值处理") elif overall_missing_rate > 10: score -= 15 issues.append(f"整体缺失率中等 ({overall_missing_rate}%)") recommendations.append("建议在分析前处理缺失值") elif overall_missing_rate > 5: score -= 5 issues.append(f"存在少量缺失 ({overall_missing_rate}%)") for col in columns: if col.get('outlierRate', 0) > 10: score -= 5 issues.append(f"列 '{col['name']}' 存在较多异常值 ({col['outlierRate']}%)") recommendations.append(f"建议检查列 '{col['name']}' 的异常值") total_rows = summary.get('totalRows', 0) if total_rows < 30: score -= 20 issues.append(f"样本量较小 (n={total_rows})") recommendations.append("小样本可能影响统计检验的效力") elif total_rows < 100: score -= 10 issues.append(f"样本量中等 (n={total_rows})") score = max(0, min(100, score)) if score >= 80: grade = 'A' grade_desc = '数据质量良好' elif score >= 60: grade = 'B' grade_desc = '数据质量中等' elif score >= 40: grade = 'C' grade_desc = '数据质量较差' else: grade = 'D' grade_desc = '数据质量很差' return { 'score': round(score, 1), 'grade': grade, 'gradeDescription': grade_desc, 'issues': issues, 'recommendations': recommendations }