Files
AIclinicalresearch/r-statistics-service/tools/descriptive.R
HaHafeng 371e1c069c feat(ssa): Complete QPER architecture - Query, Planner, Execute, Reflection layers
Implement the full QPER intelligent analysis pipeline:

- Phase E+: Block-based standardization for all 7 R tools, DynamicReport renderer, Word export enhancement

- Phase Q: LLM intent parsing with dynamic Zod validation against real column names, ClarificationCard component, DataProfile is_id_like tagging

- Phase P: ConfigLoader with Zod schema validation and hot-reload API, DecisionTableService (4-dimension matching), FlowTemplateService with EPV protection, PlannedTrace audit output

- Phase R: ReflectionService with statistical slot injection, sensitivity analysis conflict rules, ConclusionReport with section reveal animation, conclusion caching API, graceful R error classification

End-to-end test: 40/40 passed across two complete analysis scenarios.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-21 18:15:53 +08:00

410 lines
12 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_DESCRIPTIVE
#' @name 描述性统计
#' @version 1.0.0
#' @description 数据概况与基线特征表
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
# Entry point of the ST_DESCRIPTIVE tool.
#
# Loads the input data, summarises each requested variable (overall, or per
# level of an optional grouping column), renders up to four charts and
# assembles report blocks plus a reproducible R snippet.
#
# input: list supplied by the service. Fields read here:
#   input$params$variables  - optional column names; NULL/empty = all columns
#   input$params$group_var  - optional grouping column name ("" = none)
#   input$original_filename - optional, only used in the reproducible code
#   The data frame itself is obtained via load_input_data(input).
# Returns: on success a list with status, message, warnings, results
#   (summary + per-variable statistics), report_blocks, plots (base64 PNG
#   data URIs), trace_log and reproducible_code; on failure whatever
#   make_error() returns.
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- c()
  # Append a timestamped line to the trace log returned with the result.
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  variables <- p$variables  # optional column list; NULL/empty = all columns
  group_var <- p$group_var  # optional grouping column

  # ===== Variables to analyse =====
  if (is.null(variables) || length(variables) == 0) {
    variables <- names(df)
    log_add("未指定变量,分析全部列")
  }
  # The grouping column is a stratifier, not an analysed variable.
  if (!is.null(group_var) && group_var %in% variables) {
    variables <- setdiff(variables, group_var)
  }
  # Every requested column must exist in the data.
  missing_vars <- setdiff(variables, names(df))
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(missing_vars, collapse = ", ")))
  }

  # Validate the grouping column. `groups` holds its distinct non-NA values
  # in order of first appearance, or stays NULL when no grouping is requested.
  groups <- NULL
  if (!is.null(group_var) && group_var != "") {
    if (!(group_var %in% names(df))) {
      return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var))
    }
    groups <- unique(df[[group_var]][!is.na(df[[group_var]])])
    log_add(glue("分组变量: {group_var}, 分组: {paste(groups, collapse=', ')}"))
  }

  # ===== Variable type inference =====
  # Non-numeric and all-NA columns are categorical. A numeric column is also
  # treated as categorical (a coded variable) when it has at most 10 distinct
  # values AND those make up less than 5% of its non-missing values.
  # NOTE(review): with roughly 40 or fewer non-missing rows even a 0/1 code
  # fails the 5% ratio and is summarised as numeric - confirm this is intended.
  infer_type <- function(vals) {
    if (!is.numeric(vals)) {
      return("categorical")
    }
    observed <- vals[!is.na(vals)]
    if (length(observed) == 0) {
      return("categorical")
    }
    n_unique <- length(unique(observed))
    if (n_unique / length(observed) < 0.05 && n_unique <= 10) {
      return("categorical")
    }
    "numeric"
  }
  # Named character vector: one "numeric"/"categorical" entry per variable.
  var_types <- vapply(variables, function(v) infer_type(df[[v]]), character(1))
  log_add(glue("数值变量: {sum(var_types == 'numeric')}, 分类变量: {sum(var_types == 'categorical')}"))

  # ===== Descriptive statistics =====
  warnings_list <- c()
  results_list <- list()
  for (v in variables) {
    var_type <- var_types[[v]]
    describe <- if (identical(var_type, "numeric")) calc_numeric_stats else calc_categorical_stats
    vals <- df[[v]]
    if (is.null(groups)) {
      stats <- describe(vals, v)
      stats$type <- var_type
      results_list[[v]] <- stats
    } else {
      group_stats <- list()
      for (g in groups) {
        # which() drops rows whose group value is NA; a bare logical row index
        # would keep them as NA rows and inflate every group's missing count.
        in_group <- which(df[[group_var]] == g)
        group_stats[[as.character(g)]] <- describe(vals[in_group], v)
      }
      results_list[[v]] <- list(
        variable = v,
        type = var_type,
        by_group = group_stats
      )
    }
  }

  # ===== Overall summary =====
  summary_stats <- list(
    n_total = nrow(df),
    n_variables = length(variables),
    n_numeric = sum(var_types == "numeric"),
    n_categorical = sum(var_types == "categorical")
  )
  if (!is.null(groups)) {
    summary_stats$group_var <- group_var
    summary_stats$groups <- lapply(groups, function(g) {
      list(name = as.character(g), n = sum(df[[group_var]] == g, na.rm = TRUE))
    })
  }

  # ===== Charts =====
  log_add("生成描述性统计图表")
  plots <- list()
  max_plots <- 4  # chart only the first few variables to keep the report compact
  for (v in head(variables, max_plots)) {
    plot_base64 <- tryCatch({
      if (var_types[[v]] == "numeric") {
        generate_histogram(df, v, group_var)
      } else {
        generate_bar_chart(df, v, group_var)
      }
    }, error = function(e) {
      # A failed chart is logged and skipped; it does not fail the analysis.
      log_add(paste("图表生成失败:", v, e$message))
      NULL
    })
    if (!is.null(plot_base64)) {
      plots <- c(plots, list(plot_base64))
    }
  }

  # ===== Reproducible code =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }
  # glue template: `{{` / `}}` emit literal braces in the generated code.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 描述性统计
# 时间: {Sys.time()}
# ================================
library(ggplot2)
# 数据准备
df <- read.csv("{original_filename}")
# 数值变量描述性统计
numeric_vars <- sapply(df, is.numeric)
if (any(numeric_vars)) {{
summary(df[, numeric_vars, drop = FALSE])
}}
# 分类变量频数表
categorical_vars <- !numeric_vars
if (any(categorical_vars)) {{
for (v in names(df)[categorical_vars]) {{
cat("\\n变量:", v, "\\n")
print(table(df[[v]], useNA = "ifany"))
}}
}}
# 可视化示例
# ggplot(df, aes(x = your_variable)) + geom_histogram()
')

  log_add("分析完成")

  # ===== Report blocks =====
  blocks <- list()

  # Block 1: data overview as key/value pairs.
  kv_items <- list(
    "总样本量" = as.character(summary_stats$n_total),
    "变量数" = as.character(summary_stats$n_variables),
    "数值变量数" = as.character(summary_stats$n_numeric),
    "分类变量数" = as.character(summary_stats$n_categorical)
  )
  if (!is.null(groups)) {
    kv_items$group_var <- group_var
    kv_items$groups <- paste(
      vapply(summary_stats$groups, function(g) paste0(g$name, "(n=", g$n, ")"), character(1)),
      collapse = ", "
    )
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_items, title = "数据概况")

  # Names of analysed variables of the given type, in analysis order.
  vars_of_type <- function(type) {
    names(results_list)[vapply(results_list, function(x) {
      is.list(x) && identical(x$type, type)
    }, logical(1))]
  }

  # Block 2: numeric summary table. Full statistics when ungrouped; one
  # "mean ± sd" cell per group otherwise ("-" when a group has no data).
  numeric_vars <- vars_of_type("numeric")
  if (length(numeric_vars) > 0) {
    if (is.null(groups)) {
      num_headers <- c("变量名", "n", "mean", "sd", "median", "Q1", "Q3", "min", "max")
      num_rows <- lapply(numeric_vars, function(v) {
        s <- results_list[[v]]
        c(v, as.character(s$n), as.character(s$mean), as.character(s$sd),
          as.character(s$median), as.character(s$q1), as.character(s$q3),
          as.character(s$min), as.character(s$max))
      })
    } else {
      num_headers <- c("变量名", as.character(groups))
      num_rows <- lapply(numeric_vars, function(v) {
        s <- results_list[[v]]
        row <- c(v)
        for (g in groups) {
          gs <- s$by_group[[as.character(g)]]
          row <- c(row, if (!is.null(gs$formatted)) gs$formatted else "-")
        }
        row
      })
    }
    blocks[[length(blocks) + 1]] <- make_table_block(num_headers, num_rows, title = "数值变量汇总表")
  }

  # Block 3: categorical frequency table, one row per (variable[, group], level).
  cat_vars <- vars_of_type("categorical")
  if (length(cat_vars) > 0) {
    cat_headers <- c("变量名", "水平", "n", "百分比")
    cat_rows <- list()
    for (v in cat_vars) {
      s <- results_list[[v]]
      if (is.null(groups)) {
        for (lev in s$levels) {
          cat_rows[[length(cat_rows) + 1]] <- c(v, lev$level, as.character(lev$n), paste0(lev$pct, "%"))
        }
      } else {
        for (g in groups) {
          gs <- s$by_group[[as.character(g)]]
          for (lev in gs$levels) {
            cat_rows[[length(cat_rows) + 1]] <- c(paste0(v, " (", g, ")"), lev$level, as.character(lev$n), paste0(lev$pct, "%"))
          }
        }
      }
    }
    if (length(cat_rows) > 0) {
      blocks[[length(blocks) + 1]] <- make_table_block(cat_headers, cat_rows, title = "分类变量汇总表")
    }
  }

  # Block 4+: one image block per rendered chart.
  for (i in seq_along(plots)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plots[[i]], title = paste0("图表 ", i), alt = paste0("描述性统计图 ", i))
  }

  list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      summary = summary_stats,
      variables = results_list
    ),
    report_blocks = blocks,
    plots = plots,
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}
# ===== 辅助函数 =====
# Summary statistics for a numeric variable.
#
# vals: numeric vector. NA values are excluded from every statistic and
#   counted in `missing`.
# var_name: variable name copied into the result.
# Returns a list with `variable`, `n` (non-missing count) and `missing`
#   (NA count). When n == 0 it only adds `stats = NULL`. Otherwise it adds
#   mean, sd, median, q1, q3, iqr, min, max and skewness (each rounded to
#   3 dp, unnamed), plus `formatted` = "mean ± sd" rounded to 2 dp.
calc_numeric_stats <- function(vals, var_name) {
  # Count NAs before dropping them; counting afterwards always gave 0.
  n_missing <- sum(is.na(vals))
  vals <- vals[!is.na(vals)]
  n <- length(vals)
  if (n == 0) {
    return(list(
      variable = var_name,
      n = 0,
      missing = n_missing,
      stats = NULL
    ))
  }
  # quantile() returns named values ("25%"); strip names so each field is a
  # plain scalar.
  r3 <- function(x) round(unname(x), 3)
  list(
    variable = var_name,
    n = n,
    missing = n_missing,
    mean = r3(mean(vals)),
    sd = r3(sd(vals)),
    median = r3(median(vals)),
    q1 = r3(quantile(vals, 0.25)),
    q3 = r3(quantile(vals, 0.75)),
    iqr = r3(IQR(vals)),
    min = r3(min(vals)),
    max = r3(max(vals)),
    skewness = r3(calc_skewness(vals)),
    formatted = paste0(round(mean(vals), 2), " ± ", round(sd(vals), 2))
  )
}
# Frequency summary for a categorical variable.
#
# vals: vector of any type. NA values are excluded from the frequency table
#   and counted in `missing`. For a factor, table() lists every level, so
#   unused levels appear with n = 0.
# var_name: variable name copied into the result.
# Returns a list with `variable`, `n` (non-missing count), `missing` and
#   `levels`. `levels` has one entry per table level, each holding `level`,
#   `n`, `pct` (share of non-missing values, 1 dp) and `formatted`
#   ("n (pct%)").
calc_categorical_stats <- function(vals, var_name) {
  total <- length(vals)
  valid <- sum(!is.na(vals))
  freq_table <- table(vals, useNA = "no")
  # Read counts by position, not by name: in R, x[""] never matches a name,
  # so an empty-string category would otherwise get an NA count.
  level_names <- names(freq_table)
  counts <- as.numeric(freq_table)
  levels_list <- lapply(seq_along(level_names), function(i) {
    count <- counts[i]
    # Avoid 0/0 = NaN when every value is NA (e.g. a factor with unused levels).
    pct <- if (valid > 0) round(count / valid * 100, 1) else 0
    list(
      level = level_names[i],
      n = count,
      pct = pct,
      formatted = paste0(count, " (", pct, "%)")
    )
  })
  list(
    variable = var_name,
    n = valid,
    missing = total - valid,
    levels = levels_list
  )
}
# Sample skewness: sum((x - mean)^3) / (n * sd^3), where sd is the sample
# standard deviation (n - 1 denominator).
#
# x: numeric vector without NA values (the caller removes them).
# Returns a numeric scalar. Returns NA when fewer than 3 values are given or
# when x has zero spread, where skewness is undefined (previously 0/0 = NaN).
calc_skewness <- function(x) {
  n <- length(x)
  if (n < 3) {
    return(NA_real_)
  }
  s <- sd(x)
  if (!is.finite(s) || s == 0) {
    return(NA_real_)
  }
  m <- mean(x)
  sum((x - m)^3) / (n * s^3)
}
# Render a histogram of one column and return it as a base64 PNG data URI.
#
# df: data frame holding `var_name` (and `group_var`, when given).
# var_name: column to plot; rows where it is NA are dropped.
# group_var: optional column name. When non-NULL and non-empty, overlapping
#   histograms (position = "identity") are filled by its values.
# Returns a string of the form "data:image/png;base64,...".
generate_histogram <- function(df, var_name, group_var = NULL) {
  # drop = FALSE keeps a data frame even when df has a single column;
  # otherwise `df[rows, ]` would collapse to a vector and ggplot() would fail.
  plot_df <- df[!is.na(df[[var_name]]), , drop = FALSE]
  if (!is.null(group_var) && group_var != "") {
    p <- ggplot(plot_df, aes(x = .data[[var_name]], fill = factor(.data[[group_var]]))) +
      geom_histogram(alpha = 0.6, position = "identity", bins = 30) +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(plot_df, aes(x = .data[[var_name]])) +
      geom_histogram(fill = "#3b82f6", alpha = 0.7, bins = 30) +
      theme_minimal()
  }
  p <- p + labs(title = paste("Distribution of", var_name), x = var_name, y = "Count")
  tmp_file <- tempfile(fileext = ".png")
  # Remove the temp file even if ggsave() or base64encode() raises an error.
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}
# Render a frequency bar chart of one column and return it as a base64 PNG
# data URI.
#
# df: data frame holding `var_name` (and `group_var`, when given).
# var_name: column to plot (treated as a factor); rows where it is NA are
#   dropped.
# group_var: optional column name. When non-NULL and non-empty, bars are
#   dodged and filled by its values.
# Returns a string of the form "data:image/png;base64,...".
generate_bar_chart <- function(df, var_name, group_var = NULL) {
  # drop = FALSE keeps a data frame even when df has a single column;
  # otherwise `df[rows, ]` would collapse to a vector and ggplot() would fail.
  df_plot <- df[!is.na(df[[var_name]]), , drop = FALSE]
  if (!is.null(group_var) && group_var != "") {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]), fill = factor(.data[[group_var]]))) +
      geom_bar(position = "dodge") +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]))) +
      geom_bar(fill = "#3b82f6", alpha = 0.7) +
      theme_minimal()
  }
  # Rotated x labels keep long category names readable.
  p <- p + labs(title = paste("Frequency of", var_name), x = var_name, y = "Count") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  tmp_file <- tempfile(fileext = ".png")
  # Remove the temp file even if ggsave() or base64encode() raises an error.
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}