Files
AIclinicalresearch/r-statistics-service/tools/descriptive.R
HaHafeng 428a22adf2 feat(ssa): Complete Phase 2A frontend integration - multi-step workflow end-to-end
Phase 2A: WorkflowPlannerService, WorkflowExecutorService, Python data quality, 6 bug fixes, DescriptiveResultView, multi-step R code/Word export, MVP UI reuse. V11 UI: Gemini-style, multi-task, single-page scroll, Word export. Architecture: Block-based rendering consensus (4 block types). New R tools: chi_square, correlation, descriptive, logistic_binary, mann_whitney, t_test_paired. Docs: dev summary, block-based plan, status updates, task list v2.0.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-20 23:09:27 +08:00

333 lines
8.9 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_DESCRIPTIVE
#' @name 描述性统计
#' @version 1.0.0
#' @description 数据概况与基线特征表
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
#' Entry point for the ST_DESCRIPTIVE tool.
#'
#' Loads the input data, infers a type for each requested variable, computes
#' per-variable summaries (optionally split by a grouping variable), renders
#' plots for at most the first four variables and builds a reproducible R
#' snippet.
#'
#' @param input List with `params$variables` (optional; all columns when
#'   empty), `params$group_var` (optional grouping column; NULL or "" means no
#'   grouping) and optional `original_filename`. The data itself is read by
#'   `load_input_data()`, which is defined outside this file.
#' @return On success a list with `status`, `message`, `warnings`,
#'   `results` (`summary`, `variables`), `plots`, `trace_log` and
#'   `reproducible_code`; otherwise the value returned by `make_error()`.
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- character(0)
  # Appends a timestamped entry to this call's trace log.
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  variables <- p$variables # optional; empty means "analyse every column"
  group_var <- p$group_var # optional grouping column
  # Normalise an absent / empty grouping variable to NULL so every later
  # check can simply use is.null(); character(0) would otherwise reach `&&`.
  if (length(group_var) == 0 || identical(as.character(group_var), "")) {
    group_var <- NULL
  }

  # ===== Variables to analyse =====
  if (length(variables) == 0) {
    variables <- names(df)
    log_add("未指定变量,分析全部列")
  }
  # The grouping column is never summarised as a variable itself.
  if (!is.null(group_var) && group_var %in% variables) {
    variables <- setdiff(variables, group_var)
  }
  # Every requested variable must exist in the data.
  missing_vars <- setdiff(variables, names(df))
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(missing_vars, collapse = ", ")))
  }

  # Validate the grouping column and collect its non-missing levels.
  groups <- NULL
  if (!is.null(group_var)) {
    if (!(group_var %in% names(df))) {
      return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var))
    }
    group_values <- df[[group_var]]
    groups <- unique(group_values[!is.na(group_values)])
    log_add(glue("分组变量: {group_var}, 分组: {paste(groups, collapse=', ')}"))
  }

  # ===== Variable type inference =====
  # Non-numeric and all-NA columns are categorical. Numeric columns with few
  # distinct values (< 5% of non-missing values and at most 10 distinct)
  # are treated as categorical codes; everything else is numeric.
  infer_type <- function(vals) {
    if (!is.numeric(vals)) {
      return("categorical")
    }
    observed <- vals[!is.na(vals)]
    if (length(observed) == 0) {
      return("categorical")
    }
    n_unique <- length(unique(observed))
    if (n_unique / length(observed) < 0.05 && n_unique <= 10) {
      "categorical"
    } else {
      "numeric"
    }
  }
  # vapply keeps a character result (named by variable) even for no input.
  var_types <- vapply(variables, function(v) infer_type(df[[v]]), character(1))
  log_add(glue("数值变量: {sum(var_types == 'numeric')}, 分类变量: {sum(var_types == 'categorical')}"))

  # ===== Descriptive statistics =====
  warnings_list <- character(0)
  results_list <- list()
  for (v in variables) {
    var_type <- unname(var_types[v])
    if (length(var_type) == 0 || is.na(var_type)) {
      var_type <- "categorical" # fall back to categorical
    }
    calc_stats <- if (identical(var_type, "numeric")) {
      calc_numeric_stats
    } else {
      calc_categorical_stats
    }

    if (is.null(groups)) {
      # Ungrouped: one summary per variable.
      var_stats <- calc_stats(df[[v]], v)
      var_stats$type <- var_type
      results_list[[v]] <- var_stats
    } else {
      # Grouped: one summary per group level.
      group_stats <- list()
      for (g in groups) {
        # which() drops rows whose group value is NA; a logical row index
        # would instead inject all-NA rows into every group.
        in_group <- which(group_values == g)
        group_stats[[as.character(g)]] <- calc_stats(df[[v]][in_group], v)
      }
      results_list[[v]] <- list(
        variable = v,
        type = var_type,
        by_group = group_stats
      )
    }
  }

  # ===== Overall summary =====
  summary_stats <- list(
    n_total = nrow(df),
    n_variables = length(variables),
    n_numeric = sum(var_types == "numeric"),
    n_categorical = sum(var_types == "categorical")
  )
  if (!is.null(groups)) {
    summary_stats$group_var <- group_var
    summary_stats$groups <- lapply(groups, function(g) {
      list(name = as.character(g), n = sum(group_values == g, na.rm = TRUE))
    })
  }

  # ===== Plots =====
  log_add("生成描述性统计图表")
  plots <- list()
  # Only the first few variables are plotted to bound the response size.
  for (v in head(variables, 4)) {
    plot_base64 <- tryCatch({
      if (identical(unname(var_types[v]), "numeric")) {
        generate_histogram(df, v, group_var)
      } else {
        generate_bar_chart(df, v, group_var)
      }
    }, error = function(e) {
      # A failed plot is logged and skipped rather than failing the analysis.
      log_add(paste("图表生成失败:", v, e$message))
      NULL
    })
    if (!is.null(plot_base64)) {
      plots <- c(plots, list(plot_base64))
    }
  }

  # ===== Reproducible code =====
  original_filename <- if (isTRUE(nzchar(input$original_filename))) {
    input$original_filename
  } else {
    "data.csv"
  }
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 描述性统计
# 时间: {Sys.time()}
# ================================
library(ggplot2)
# 数据准备
df <- read.csv("{original_filename}")
# 数值变量描述性统计
numeric_vars <- sapply(df, is.numeric)
if (any(numeric_vars)) {{
summary(df[, numeric_vars, drop = FALSE])
}}
# 分类变量频数表
categorical_vars <- !numeric_vars
if (any(categorical_vars)) {{
for (v in names(df)[categorical_vars]) {{
cat("\\n变量:", v, "\\n")
print(table(df[[v]], useNA = "ifany"))
}}
}}
# 可视化示例
# ggplot(df, aes(x = your_variable)) + geom_histogram()
')

  # ===== Result =====
  log_add("分析完成")
  list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      summary = summary_stats,
      variables = results_list
    ),
    plots = plots,
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}
# ===== 辅助函数 =====
# Summary statistics for a numeric variable.
#
# NA values are counted as missing and excluded from every statistic.
# Statistics are rounded to 3 decimals; `formatted` is "mean ± sd" with 2.
#
# @param vals Numeric vector, possibly containing NA.
# @param var_name Variable name echoed back in the result.
# @return Named list with `variable`, `n` (non-missing count), `missing`
#   (NA count) and the statistics; when every value is NA only
#   `variable`, `n = 0`, `missing` and `stats = NULL` are returned.
calc_numeric_stats <- function(vals, var_name) {
  total <- length(vals)
  vals <- vals[!is.na(vals)]
  n <- length(vals)
  # Count missing values BEFORE they were stripped (previously always 0).
  n_missing <- total - n
  if (n == 0) {
    return(list(
      variable = var_name,
      n = 0,
      missing = n_missing,
      stats = NULL
    ))
  }
  mean_val <- mean(vals)
  sd_val <- sd(vals) # NA when n == 1
  list(
    variable = var_name,
    n = n,
    missing = n_missing,
    mean = round(mean_val, 3),
    sd = round(sd_val, 3),
    median = round(median(vals), 3),
    # unname() drops the "25%"/"75%" labels quantile() attaches.
    q1 = round(unname(quantile(vals, 0.25)), 3),
    q3 = round(unname(quantile(vals, 0.75)), 3),
    iqr = round(IQR(vals), 3),
    min = round(min(vals), 3),
    max = round(max(vals), 3),
    skewness = round(calc_skewness(vals), 3),
    formatted = paste0(round(mean_val, 2), " ± ", round(sd_val, 2))
  )
}
# Frequency summary for a categorical variable.
#
# NA values are counted as missing and excluded from the frequency table;
# percentages are relative to the non-missing count and rounded to 1 decimal.
#
# @param vals Vector (character, factor, logical, numeric codes, ...).
# @param var_name Variable name echoed back in the result.
# @return Named list with `variable`, `n` (non-missing count), `missing`
#   and `levels`, an unnamed list of per-level `level`, `n`, `pct`,
#   `formatted` ("n (pct%)") entries in table order.
calc_categorical_stats <- function(vals, var_name) {
  total <- length(vals)
  valid <- sum(!is.na(vals))
  freq_table <- table(vals, useNA = "no")
  # Read counts positionally: looking them up by name fails for the
  # empty-string level, because "" never matches a name when subsetting.
  level_names <- names(freq_table)
  counts <- as.numeric(freq_table)
  levels_list <- lapply(seq_along(counts), function(i) {
    count <- counts[i]
    # Guard the all-missing case (e.g. a factor with declared levels only).
    pct <- if (valid > 0) round(count / valid * 100, 1) else 0
    list(
      level = level_names[i],
      n = count,
      pct = pct,
      formatted = paste0(count, " (", pct, "%)")
    )
  })
  list(
    variable = var_name,
    n = valid,
    missing = total - valid,
    levels = levels_list
  )
}
# Moment-based skewness: sum((x - mean)^3) / (n * sd^3), using the sample
# standard deviation from sd().
#
# @param x Numeric vector without NA values.
# @return A numeric scalar, or NA_real_ when it is undefined (fewer than 3
#   values, or zero / undefined spread).
calc_skewness <- function(x) {
  n <- length(x)
  if (n < 3) {
    return(NA_real_)
  }
  s <- sd(x)
  # Constant input would divide 0 by 0 and produce NaN.
  if (is.na(s) || s == 0) {
    return(NA_real_)
  }
  m <- mean(x)
  sum((x - m)^3) / (n * s^3)
}
# Render a histogram of `var_name` and return it as a PNG data URI.
#
# Rows where `var_name` is NA are dropped. When `group_var` names a column,
# rows with an NA group are dropped as well and bars are filled per group.
#
# @param df Data frame containing the columns.
# @param var_name Column to plot (expected to be numeric).
# @param group_var Optional grouping column; NULL or "" means no grouping.
# @return Character scalar of the form "data:image/png;base64,...".
generate_histogram <- function(df, var_name, group_var = NULL) {
  has_group <- length(group_var) == 1 && !is.na(group_var) && nzchar(group_var)
  keep <- !is.na(df[[var_name]])
  if (has_group) {
    # Exclude missing groups so the legend has no spurious NA level.
    keep <- keep & !is.na(df[[group_var]])
  }
  # drop = FALSE keeps a data frame even when `df` has a single column.
  plot_df <- df[keep, , drop = FALSE]
  if (has_group) {
    p <- ggplot(plot_df, aes(x = .data[[var_name]], fill = factor(.data[[group_var]]))) +
      geom_histogram(alpha = 0.6, position = "identity", bins = 30) +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(plot_df, aes(x = .data[[var_name]])) +
      geom_histogram(fill = "#3b82f6", alpha = 0.7, bins = 30) +
      theme_minimal()
  }
  p <- p + labs(title = paste("Distribution of", var_name), x = var_name, y = "Count")
  tmp_file <- tempfile(fileext = ".png")
  # Remove the temporary PNG even if saving or encoding fails.
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}
# Render a bar chart of the level frequencies of `var_name` and return it as
# a PNG data URI.
#
# Rows where `var_name` is NA are dropped. When `group_var` names a column,
# rows with an NA group are dropped as well and bars are dodged per group.
#
# @param df Data frame containing the columns.
# @param var_name Column to plot; values are shown as factor levels.
# @param group_var Optional grouping column; NULL or "" means no grouping.
# @return Character scalar of the form "data:image/png;base64,...".
generate_bar_chart <- function(df, var_name, group_var = NULL) {
  has_group <- length(group_var) == 1 && !is.na(group_var) && nzchar(group_var)
  keep <- !is.na(df[[var_name]])
  if (has_group) {
    # Exclude missing groups so the legend has no spurious NA level.
    keep <- keep & !is.na(df[[group_var]])
  }
  # drop = FALSE keeps a data frame even when `df` has a single column.
  df_plot <- df[keep, , drop = FALSE]
  if (has_group) {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]), fill = factor(.data[[group_var]]))) +
      geom_bar(position = "dodge") +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]))) +
      geom_bar(fill = "#3b82f6", alpha = 0.7) +
      theme_minimal()
  }
  # Rotated x labels keep long category names readable.
  p <- p + labs(title = paste("Frequency of", var_name), x = var_name, y = "Count") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  tmp_file <- tempfile(fileext = ".png")
  # Remove the temporary PNG even if saving or encoding fails.
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}