Bug fixes:
- Fix garbled error messages in chat (TypeWriter rendering issue)
- Fix R engine NA crash in descriptive.R (defensive isTRUE/is.na checks)
- Fix intent misclassification for statistical significance queries
- Fix step 2 results not displayed (accept warning status alongside success)
- Fix incomplete R code download (only step 1 included)
- Fix multi-task state confusion (clicking old card shows new results)
- Add R engine and backend parameter logging for debugging

Refactor - Unified Record Architecture:
- Replace 12 global singleton fields with AnalysisRecord as single source of truth
- Remove isWorkflowMode branching across all components
- One Analysis = One Record = N Steps paradigm
- selectRecord only sets currentRecordId, all rendering derives from currentRecord
- Fix cross-hook-instance issue: executeWorkflow fallback to store currentRecordId

Updated files: ssaStore, useWorkflow, useAnalysis, SSAChatPane, SSAWorkspacePane, SSACodeModal, WorkflowTimeline, QueryService, WorkflowExecutorService, descriptive.R

Tested: Manual integration test passed - multi-task switching, R code completeness

Co-authored-by: Cursor <cursoragent@cursor.com>
492 lines
15 KiB
R
492 lines
15 KiB
R
#' @tool_code ST_DESCRIPTIVE
|
|
#' @name 描述性统计
|
|
#' @version 1.0.0
|
|
#' @description 数据概况与基线特征表
|
|
#' @author SSA-Pro Team
|
|
|
|
library(glue)
|
|
library(ggplot2)
|
|
library(base64enc)
|
|
|
|
# Entry point of the ST_DESCRIPTIVE tool.
#
# Loads the input data via load_input_data() (defined elsewhere), infers a
# numeric/categorical type for each requested variable, computes overall or
# per-group descriptive statistics, renders up to four plots, and assembles
# report blocks plus a reproducible R script.
#
# input: list with `params$variables` (optional; defaults to every column),
#   `params$group_var` (optional single column name) and optionally
#   `original_filename` (used in the generated read.csv() call).
# Returns: on success a list with status, message, warnings, results,
#   report_blocks, plots, trace_log and reproducible_code; when loading fails
#   or a requested column is missing, the value returned by make_error().
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- character(0)
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  variables <- p$variables

  # Normalise group_var to NULL or a single non-blank string, never NA.
  # Taking the first element after unlist() also covers NULL, list() and
  # vectors such as c(NA, "x") whose first element is NA.
  group_var <- as.character(unlist(p$group_var, use.names = FALSE))[1]
  if (is.na(group_var) || !nzchar(trimws(group_var))) {
    group_var <- NULL
  }

  group_label <- if (is.null(group_var)) "NULL" else group_var
  log_add(glue("=== 输入参数 === variables: [{paste(variables, collapse=', ')}], group_var: {group_label}"))
  log_add(glue("=== 数据列 === [{paste(names(df), collapse=', ')}]"))

  # ===== Variables to analyse =====
  if (is.null(variables) || length(variables) == 0) {
    variables <- names(df)
    log_add("未指定变量,分析全部列")
  }
  # unique(): a duplicated name would otherwise overwrite its own entry in
  # results_list while still being counted twice in n_variables.
  variables <- unique(as.character(unlist(variables, use.names = FALSE)))

  # The grouping column itself is never summarised as an analysis variable.
  if (!is.null(group_var)) {
    variables <- setdiff(variables, group_var)
  }

  missing_vars <- setdiff(variables, names(df))
  if (length(missing_vars) > 0) {
    log_add(glue("缺失变量: [{paste(missing_vars, collapse=', ')}]"))
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(missing_vars, collapse = ", ")))
  }
  log_add(glue("最终分析变量 ({length(variables)}): [{paste(variables, collapse=', ')}]"))

  # ===== Grouping =====
  # Groups are handled as character labels so that factor/Date/logical group
  # columns keep their printed labels as list keys and table headers.
  # group_rows[[i]] holds the row indices of groups[[i]] (NA rows excluded).
  groups <- NULL
  group_rows <- NULL
  if (!is.null(group_var)) {
    if (!(group_var %in% names(df))) {
      return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var))
    }
    group_values <- as.character(df[[group_var]])
    groups <- unique(group_values[!is.na(group_values)])
    group_rows <- split(seq_along(group_values), factor(group_values, levels = groups))
    log_add(glue("分组变量: {group_var}, 分组: {paste(groups, collapse=', ')}"))
  }

  # ===== Variable type inference =====
  # Numeric columns with few distinct values (<= 10 and < 5% of non-missing
  # observations) are treated as categorical codes; all-missing and
  # non-numeric columns are categorical.
  infer_type <- function(vals) {
    if (is.null(vals) || !is.numeric(vals)) {
      return("categorical")
    }
    observed <- vals[!is.na(vals)]
    if (length(observed) == 0) {
      return("categorical")
    }
    unique_count <- length(unique(observed))
    if (unique_count / length(observed) < 0.05 && unique_count <= 10) {
      return("categorical")
    }
    "numeric"
  }
  var_types <- tryCatch(
    setNames(
      vapply(variables, function(v) infer_type(df[[v]]), character(1), USE.NAMES = FALSE),
      variables
    ),
    error = function(e) {
      log_add(paste("变量类型推断失败:", e$message))
      setNames(rep("categorical", length(variables)), variables)
    }
  )

  log_add(glue("数值变量: {sum(var_types == 'numeric', na.rm=TRUE)}, 分类变量: {sum(var_types == 'categorical', na.rm=TRUE)}"))
  log_add(glue("var_types 详情: {paste(names(var_types), '=', var_types, collapse=', ')}"))

  # ===== Descriptive statistics =====
  warnings_list <- character(0)
  results_list <- list()

  for (v in variables) {
    # Check the length before is.na(): `||` errors on zero-length operands.
    var_type <- unname(var_types[v])
    if (length(var_type) != 1L || is.na(var_type)) {
      var_type <- "categorical"
    }
    calc_stats <- if (identical(var_type, "numeric")) calc_numeric_stats else calc_categorical_stats

    if (is.null(groups)) {
      stats <- calc_stats(df[[v]], v)
      stats$type <- var_type
      results_list[[v]] <- stats
    } else {
      col_vals <- df[[v]]
      by_group <- lapply(group_rows, function(rows) calc_stats(col_vals[rows], v))
      names(by_group) <- groups
      results_list[[v]] <- list(
        variable = v,
        type = var_type,
        by_group = by_group
      )
    }
  }

  # ===== Overall summary =====
  summary_stats <- list(
    n_total = nrow(df),
    n_variables = length(variables),
    n_numeric = sum(var_types == "numeric"),
    n_categorical = sum(var_types == "categorical")
  )
  if (!is.null(groups)) {
    summary_stats$group_var <- group_var
    summary_stats$groups <- lapply(seq_along(groups), function(i) {
      list(name = groups[[i]], n = length(group_rows[[i]]))
    })
  }

  # ===== Plots =====
  log_add("生成描述性统计图表")
  # Only the first few variables are plotted to keep the report compact.
  vars_to_plot <- head(variables, 4)
  plots <- list()
  for (v in vars_to_plot) {
    plot_base64 <- tryCatch({
      if (isTRUE(var_types[v] == "numeric")) {
        generate_histogram(df, v, group_var)
      } else {
        generate_bar_chart(df, v, group_var)
      }
    }, error = function(e) {
      log_add(paste("图表生成失败:", v, e$message))
      NULL
    })
    if (!is.null(plot_base64)) {
      plots[[length(plots) + 1]] <- plot_base64
    }
  }

  # ===== Reproducible code =====
  # Escape backslashes, then double quotes, so arbitrary column names and file
  # paths can be embedded in double-quoted R string literals.
  escape_r_string <- function(x) {
    gsub('"', '\\\\"', gsub("\\\\", "\\\\\\\\", x))
  }

  original_filename <- as.character(unlist(input$original_filename, use.names = FALSE))[1]
  if (is.na(original_filename) || !nzchar(original_filename)) {
    original_filename <- "data.csv"
  }
  safe_filename <- escape_r_string(original_filename)

  # Quoted, comma-separated list of the analysed columns for c(...).
  analysis_vars_code <- if (length(variables) == 0) {
    ""
  } else {
    paste0('"', escape_r_string(variables), '"', collapse = ", ")
  }

  # One ggplot snippet per plotted variable, mirroring generate_histogram() /
  # generate_bar_chart(). `drop = FALSE` keeps a data frame even when the
  # data has a single column.
  build_plot_snippet <- function(v) {
    safe_v <- escape_r_string(v)
    safe_var_name <- gsub("[^a-zA-Z0-9]", "_", v)
    is_numeric <- identical(unname(var_types[v]), "numeric")

    if (is_numeric && !is.null(group_var)) {
      safe_g <- escape_r_string(group_var)
      snippet <- glue('
        # Histogram: {safe_v}
        p_{safe_var_name} <- ggplot(df[!is.na(df[["{safe_v}"]]), , drop = FALSE], aes(x = .data[["{safe_v}"]], fill = factor(.data[["{safe_g}"]]))) +
          geom_histogram(alpha = 0.6, position = "identity", bins = 30) +
          scale_fill_brewer(palette = "Set1", name = "{safe_g}") +
          labs(title = "Distribution of {safe_v}", x = "{safe_v}", y = "Count") +
          theme_minimal()
        print(p_{safe_var_name})
      ')
    } else if (is_numeric) {
      snippet <- glue('
        # Histogram: {safe_v}
        p_{safe_var_name} <- ggplot(df[!is.na(df[["{safe_v}"]]), , drop = FALSE], aes(x = .data[["{safe_v}"]])) +
          geom_histogram(fill = "#3b82f6", alpha = 0.7, bins = 30) +
          labs(title = "Distribution of {safe_v}", x = "{safe_v}", y = "Count") +
          theme_minimal()
        print(p_{safe_var_name})
      ')
    } else if (!is.null(group_var)) {
      safe_g <- escape_r_string(group_var)
      snippet <- glue('
        # Bar chart: {safe_v}
        p_{safe_var_name} <- ggplot(df[!is.na(df[["{safe_v}"]]), , drop = FALSE], aes(x = factor(.data[["{safe_v}"]]), fill = factor(.data[["{safe_g}"]]))) +
          geom_bar(position = "dodge") +
          scale_fill_brewer(palette = "Set1", name = "{safe_g}") +
          labs(title = "Frequency of {safe_v}", x = "{safe_v}", y = "Count") +
          theme_minimal() +
          theme(axis.text.x = element_text(angle = 45, hjust = 1))
        print(p_{safe_var_name})
      ')
    } else {
      snippet <- glue('
        # Bar chart: {safe_v}
        p_{safe_var_name} <- ggplot(df[!is.na(df[["{safe_v}"]]), , drop = FALSE], aes(x = factor(.data[["{safe_v}"]]))) +
          geom_bar(fill = "#3b82f6", alpha = 0.7) +
          labs(title = "Frequency of {safe_v}", x = "{safe_v}", y = "Count") +
          theme_minimal() +
          theme(axis.text.x = element_text(angle = 45, hjust = 1))
        print(p_{safe_var_name})
      ')
    }
    as.character(snippet)
  }

  plot_code_section <- tryCatch(
    paste(vapply(vars_to_plot, build_plot_snippet, character(1), USE.NAMES = FALSE), collapse = "\n"),
    error = function(e) {
      log_add(paste("reproducible_code visualization generation failed:", e$message))
      "# ggplot(df, aes(x = your_variable)) + geom_histogram()"
    }
  )

  # check.names = FALSE keeps the original column names so the df[["..."]]
  # lookups below match; summaries are restricted to the analysed variables.
  reproducible_code <- glue('
    # SSA-Pro 自动生成代码
    # 工具: 描述性统计
    # 时间: {Sys.time()}
    # ================================

    library(ggplot2)

    # 数据准备
    df <- read.csv("{safe_filename}", check.names = FALSE)

    # 分析变量
    analysis_vars <- c({analysis_vars_code})
    df_analysis <- df[, analysis_vars, drop = FALSE]

    # 数值变量描述性统计
    numeric_vars <- vapply(df_analysis, is.numeric, logical(1))
    if (any(numeric_vars)) {{
      print(summary(df_analysis[, numeric_vars, drop = FALSE]))
    }}

    # 分类变量频数表
    categorical_vars <- !numeric_vars
    if (any(categorical_vars)) {{
      for (v in names(df_analysis)[categorical_vars]) {{
        cat("\\n变量:", v, "\\n")
        print(table(df_analysis[[v]], useNA = "ifany"))
      }}
    }}

    # ======== 可视化 ========
    {plot_code_section}
  ')

  log_add("分析完成")

  # ===== Report blocks =====
  # Statistics become table cells; NULL, empty or NA values render as "-" so
  # every row keeps the same width as its header.
  fmt_cell <- function(x) {
    if (is.null(x) || length(x) == 0L || is.na(x[[1]])) "-" else as.character(x[[1]])
  }
  vars_of_type <- function(type) {
    keep <- vapply(results_list, function(x) is.list(x) && identical(x$type, type), logical(1))
    names(results_list)[keep]
  }

  blocks <- list()

  # Block 1: data overview
  kv_items <- list(
    "总样本量" = as.character(summary_stats$n_total),
    "变量数" = as.character(summary_stats$n_variables),
    "数值变量数" = as.character(summary_stats$n_numeric),
    "分类变量数" = as.character(summary_stats$n_categorical)
  )
  if (!is.null(groups)) {
    kv_items$group_var <- group_var
    kv_items$groups <- paste(
      vapply(summary_stats$groups, function(g) paste0(g$name, "(n=", g$n, ")"), character(1)),
      collapse = ", "
    )
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_items, title = "数据概况")

  # Block 2: numeric summary table (one row per variable; grouped results are
  # read by position because groups and by_group share the same order)
  numeric_vars <- vars_of_type("numeric")
  if (length(numeric_vars) > 0) {
    if (is.null(groups)) {
      num_headers <- c("变量名", "n", "mean", "sd", "median", "Q1", "Q3", "min", "max")
      num_rows <- lapply(numeric_vars, function(v) {
        s <- results_list[[v]]
        cells <- list(s$n, s$mean, s$sd, s$median, s$q1, s$q3, s$min, s$max)
        c(v, vapply(cells, fmt_cell, character(1)))
      })
    } else {
      num_headers <- c("变量名", groups)
      num_rows <- lapply(numeric_vars, function(v) {
        by_group <- results_list[[v]]$by_group
        c(v, vapply(seq_along(groups), function(i) fmt_cell(by_group[[i]]$formatted), character(1)))
      })
    }
    blocks[[length(blocks) + 1]] <- make_table_block(num_headers, num_rows, title = "数值变量汇总表")
  }

  # Block 3: categorical frequency table (one row per level, per group)
  cat_vars <- vars_of_type("categorical")
  if (length(cat_vars) > 0) {
    cat_headers <- c("变量名", "水平", "n", "百分比")
    level_rows <- function(label, levels) {
      lapply(levels, function(lev) {
        c(label, lev$level, as.character(lev$n), paste0(lev$pct, "%"))
      })
    }
    cat_rows <- list()
    for (v in cat_vars) {
      s <- results_list[[v]]
      if (is.null(groups)) {
        cat_rows <- c(cat_rows, level_rows(v, s$levels))
      } else {
        for (i in seq_along(groups)) {
          label <- paste0(v, " (", groups[[i]], ")")
          cat_rows <- c(cat_rows, level_rows(label, s$by_group[[i]]$levels))
        }
      }
    }
    if (length(cat_rows) > 0) {
      blocks[[length(blocks) + 1]] <- make_table_block(cat_headers, cat_rows, title = "分类变量汇总表")
    }
  }

  # Block 4+: one image block per rendered plot
  for (i in seq_along(plots)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plots[[i]], title = paste0("图表 ", i), alt = paste0("描述性统计图 ", i))
  }

  list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      summary = summary_stats,
      variables = results_list
    ),
    report_blocks = blocks,
    plots = plots,
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}
|
|
|
|
# ===== 辅助函数 =====
|
|
|
|
# Descriptive statistics for one numeric vector.
#
# vals: numeric vector, possibly containing NA.
# var_name: label echoed back in the `variable` field.
# Returns a list with variable, n (non-missing count) and missing (NA count).
# When at least one value is present it also has mean, sd, median, q1, q3,
# iqr, min, max and skewness (rounded to 3 decimals), plus `formatted` as
# "mean ± sd" rounded to 2 decimals. When every value is NA, `stats` is NULL.
calc_numeric_stats <- function(vals, var_name) {
  # Count missing values before dropping them; previously this was computed
  # after filtering and was therefore always 0.
  n_missing <- sum(is.na(vals))
  vals <- vals[!is.na(vals)]
  n <- length(vals)

  if (n == 0) {
    return(list(
      variable = var_name,
      n = 0,
      missing = n_missing,
      stats = NULL
    ))
  }

  mean_val <- mean(vals)
  sd_val <- sd(vals)
  list(
    variable = var_name,
    n = n,
    missing = n_missing,
    mean = round(mean_val, 3),
    sd = round(sd_val, 3),
    median = round(median(vals), 3),
    # names = FALSE: return plain numbers instead of "25%"/"75%"-named vectors
    q1 = round(quantile(vals, 0.25, names = FALSE), 3),
    q3 = round(quantile(vals, 0.75, names = FALSE), 3),
    iqr = round(IQR(vals), 3),
    min = round(min(vals), 3),
    max = round(max(vals), 3),
    skewness = round(calc_skewness(vals), 3),
    formatted = paste0(round(mean_val, 2), " ± ", round(sd_val, 2))
  )
}
|
|
|
|
# Frequency table for one categorical vector.
#
# vals: vector of any atomic type (or factor), possibly containing NA.
# var_name: label echoed back in the `variable` field.
# Returns a list with variable, n (non-missing count), missing (NA count) and
# levels: one entry per table() level with level, n, pct (percent of
# non-missing values, 1 decimal) and formatted "n (pct%)".
calc_categorical_stats <- function(vals, var_name) {
  total <- length(vals)
  valid <- sum(!is.na(vals))

  freq_table <- table(vals, useNA = "no")
  level_names <- names(freq_table)
  counts <- as.numeric(freq_table)

  # Index by position rather than by name: `freq_table[""]` matches nothing in
  # R, so an empty-string level (e.g. a blank CSV cell) used to get count NA.
  levels_list <- lapply(seq_along(counts), function(i) {
    count <- counts[[i]]
    pct <- round(count / valid * 100, 1)
    list(
      level = level_names[[i]],
      n = count,
      pct = pct,
      formatted = paste0(count, " (", pct, "%)")
    )
  })

  list(
    variable = var_name,
    n = valid,
    missing = total - valid,
    levels = levels_list
  )
}
|
|
|
|
# Sample skewness sum((x - mean)^3) / (n * sd^3), using the n-1 sample sd.
#
# x: numeric vector without NA (the caller removes missing values).
# Returns NA_real_ when fewer than 3 values are given or the sample sd is 0 or
# non-finite, where skewness is undefined.
calc_skewness <- function(x) {
  n <- length(x)
  if (n < 3) return(NA_real_)
  s <- sd(x)
  # A constant sample would otherwise divide by zero and produce NaN.
  if (!is.finite(s) || s == 0) return(NA_real_)
  m <- mean(x)
  sum((x - m)^3) / (n * s^3)
}
|
|
|
|
# Render a histogram of `var_name` and return it as a base64 PNG data URI.
#
# df: data frame holding `var_name` (and `group_var` when given).
# group_var: optional column name; when it is a single non-empty string the
#   bars are filled by factor(group_var) with the Set1 palette.
# Rows where `var_name` is NA are dropped before plotting. The temporary PNG
# is deleted on exit, even if ggsave() or base64encode() fails.
generate_histogram <- function(df, var_name, group_var = NULL) {
  has_group <- !is.null(group_var) && length(group_var) == 1L &&
    !is.na(group_var) && nzchar(group_var)
  # drop = FALSE: a single-column data frame would otherwise become a vector.
  plot_df <- df[!is.na(df[[var_name]]), , drop = FALSE]

  if (has_group) {
    p <- ggplot(plot_df, aes(x = .data[[var_name]], fill = factor(.data[[group_var]]))) +
      geom_histogram(alpha = 0.6, position = "identity", bins = 30) +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(plot_df, aes(x = .data[[var_name]])) +
      geom_histogram(fill = "#3b82f6", alpha = 0.7, bins = 30) +
      theme_minimal()
  }

  p <- p + labs(title = paste("Distribution of", var_name), x = var_name, y = "Count")

  tmp_file <- tempfile(fileext = ".png")
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}
|
|
|
|
# Render a frequency bar chart of `var_name` and return it as a base64 PNG
# data URI.
#
# df: data frame holding `var_name` (and `group_var` when given).
# group_var: optional column name; when it is a single non-empty string the
#   bars are dodged and filled by factor(group_var) with the Set1 palette.
# Rows where `var_name` is NA are dropped before plotting; x labels are
# rotated 45 degrees. The temporary PNG is deleted on exit, even on failure.
generate_bar_chart <- function(df, var_name, group_var = NULL) {
  has_group <- !is.null(group_var) && length(group_var) == 1L &&
    !is.na(group_var) && nzchar(group_var)
  # drop = FALSE: a single-column data frame would otherwise become a vector.
  df_plot <- df[!is.na(df[[var_name]]), , drop = FALSE]

  if (has_group) {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]), fill = factor(.data[[group_var]]))) +
      geom_bar(position = "dodge") +
      scale_fill_brewer(palette = "Set1", name = group_var) +
      theme_minimal()
  } else {
    p <- ggplot(df_plot, aes(x = factor(.data[[var_name]]))) +
      geom_bar(fill = "#3b82f6", alpha = 0.7) +
      theme_minimal()
  }

  p <- p + labs(title = paste("Frequency of", var_name), x = var_name, y = "Count") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  tmp_file <- tempfile(fileext = ".png")
  on.exit(unlink(tmp_file), add = TRUE)
  ggsave(tmp_file, p, width = 6, height = 4, dpi = 100)
  paste0("data:image/png;base64,", base64encode(tmp_file))
}
|