Phase I - Session Blackboard + READ Layer: - SessionBlackboardService with Postgres-Only cache - DataProfileService for data overview generation - PicoInferenceService for LLM-driven PICO extraction - Frontend DataContextCard and VariableDictionaryPanel - E2E tests: 31/31 passed Phase II - Conversation Layer LLM + Intent Router: - ConversationService with SSE streaming - IntentRouterService (rule-first + LLM fallback, 6 intents) - SystemPromptService with 6-segment dynamic assembly - TokenTruncationService for context management - ChatHandlerService as unified chat entry - Frontend SSAChatPane and useSSAChat hook - E2E tests: 38/38 passed Phase III - Method Consultation + AskUser Standardization: - ToolRegistryService with Repository Pattern - MethodConsultService with DecisionTable + LLM enhancement - AskUserService with global interrupt handling - Frontend AskUserCard component - E2E tests: 13/13 passed Phase IV - Dialogue-Driven Analysis + QPER Integration: - ToolOrchestratorService (plan/execute/report) - analysis_plan SSE event for WorkflowPlan transmission - Dual-channel confirmation (ask_user card + workspace button) - PICO as optional hint for LLM parsing - E2E tests: 25/25 passed R Statistics Service: - 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon - Enhanced guardrails and block helpers - Comprehensive test suite (run_all_tools_test.js) Documentation: - Updated system status document (v5.9) - Updated SSA module status and development plan (v1.8) Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25) Co-authored-by: Cursor <cursoragent@cursor.com>
317 lines
9.5 KiB
R
317 lines
9.5 KiB
R
#' @tool_code ST_BASELINE_TABLE
#' @name 基线特征表(复合工具)
#' @version 1.0.0
#' @description 基于 gtsummary 的一键式基线特征表生成,自动判断变量类型、选择统计方法、输出标准三线表
#' @author SSA-Pro Team
#' @note 复合工具:一次遍历所有变量,自动选方法(T/Wilcoxon/χ²/Fisher),合并出表

library(glue)
library(ggplot2)
library(base64enc)

#' Generate a baseline characteristics table (composite tool).
#'
#' One pass over all requested variables: loads the data, validates the
#' grouping / analysis variables, builds a gtsummary three-line table with
#' automatically selected tests (via `add_p()`), and packages the result as
#' structured output, report blocks and reproducible R code.
#'
#' @param input A list consumed by `load_input_data()`; additionally reads
#'   `input$params$group_var`, `input$params$analyze_vars` and, when present,
#'   `input$original_filename` (embedded in the generated reproducible code).
#' @return On success a list with `status = "success"`, `message`, `warnings`,
#'   `results`, `report_blocks`, `plots`, `trace_log` and `reproducible_code`;
#'   otherwise an error structure from `make_error()` / `map_r_error()`, or a
#'   `status = "blocked"` guardrail result.
run_analysis <- function(input) {
  # ----- Initialisation -----
  logs <- c()
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }
  warnings_list <- c()
  # NOTE(review): the original registered a no-op `on.exit({}, add = TRUE)`;
  # removed — register a cleanup handler only when there is work to undo.

  # ----- Dependency check -----
  required_pkgs <- c("gtsummary", "gt", "broom")
  for (pkg in required_pkgs) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      return(make_error(ERROR_CODES$E101_PACKAGE_MISSING, package = pkg))
    }
  }

  library(gtsummary)
  library(dplyr)

  # ----- Data loading -----
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )

  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  group_var <- p$group_var
  analyze_vars <- as.character(unlist(p$analyze_vars))

  # ----- Parameter validation -----
  if (is.null(group_var) || !(group_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var %||% "NULL"))
  }

  # Default: analyse every column except the grouping variable.
  if (is.null(analyze_vars) || length(analyze_vars) == 0) {
    analyze_vars <- setdiff(names(df), group_var)
    log_add(glue("未指定分析变量,自动选取全部 {length(analyze_vars)} 个变量"))
  }

  missing_vars <- analyze_vars[!(analyze_vars %in% names(df))]
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(missing_vars, collapse = ", ")))
  }

  # ----- Data cleaning: drop rows whose group value is NA or blank -----
  original_rows <- nrow(df)
  df <- df[!is.na(df[[group_var]]) & trimws(as.character(df[[group_var]])) != "", ]
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("分组变量缺失值清洗: 移除 {removed_rows} 行 (剩余 {nrow(df)} 行)"))
  }

  groups <- unique(df[[group_var]])
  n_groups <- length(groups)
  if (n_groups < 2) {
    return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
                      col = group_var, expected = "2+", actual = n_groups))
  }

  # Guardrail: hard block below the minimum sample size.
  sample_check <- check_sample_size(nrow(df), min_required = 10, action = ACTION_BLOCK)
  if (!sample_check$passed) {
    return(list(status = "blocked", message = sample_check$reason, trace_log = logs))
  }

  # gtsummary expects a factor grouping column.
  df[[group_var]] <- as.factor(df[[group_var]])

  # Restrict to the columns under analysis.
  df_analysis <- df[, c(group_var, analyze_vars), drop = FALSE]

  log_add(glue("分组变量: {group_var} ({n_groups} 组: {paste(groups, collapse=', ')})"))
  log_add(glue("分析变量: {length(analyze_vars)} 个"))

  # ----- Core computation: gtsummary -----
  log_add("使用 gtsummary 生成基线特征表")

  # Warnings are collected into the result (and muffled) rather than surfaced;
  # errors make `tbl` NULL so they can be mapped to an error payload below.
  tbl <- tryCatch(
    withCallingHandlers(
      {
        tbl_summary(
          df_analysis,
          by = all_of(group_var),
          missing = "ifany",
          statistic = list(
            all_continuous() ~ "{mean} ({sd})",
            all_categorical() ~ "{n} ({p}%)"
          ),
          digits = list(
            all_continuous() ~ 2,
            all_categorical() ~ c(0, 1)
          )
        ) %>%
          add_p() %>%
          add_overall()
      },
      warning = function(w) {
        warnings_list <<- c(warnings_list, w$message)
        log_add(paste("gtsummary 警告:", w$message))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("gtsummary 生成失败:", e$message))
      NULL
    }
  )

  if (is.null(tbl)) {
    return(map_r_error("gtsummary 基线特征表生成失败"))
  }

  log_add("gtsummary 表格生成成功")

  # ----- Structured extraction -----
  # NOTE(review): the original also materialised `as.data.frame(tbl$table_body)`
  # into an unused local (`tbl_df`); removed as dead code.

  # Variables significant at alpha = 0.05.
  significant_vars <- extract_significant_vars(tbl, alpha = 0.05)
  log_add(glue("显著变量 (P < 0.05): {length(significant_vars)} 个"))

  # Statistical method chosen per variable.
  method_info <- extract_method_info(tbl)

  # ----- Convert to report_blocks -----
  log_add("转换 gtsummary → report_blocks")
  blocks <- gtsummary_to_blocks(tbl, group_var, groups, analyze_vars, significant_vars)

  # ----- Structured result payload -----
  output_results <- list(
    method = "gtsummary::tbl_summary + add_p",
    group_var = group_var,
    n_groups = n_groups,
    groups = lapply(groups, function(g) {
      list(label = as.character(g), n = sum(df[[group_var]] == g))
    }),
    n_variables = length(analyze_vars),
    significant_vars = significant_vars,
    method_info = method_info,
    total_n = nrow(df)
  )

  # ----- Reproducible code generation -----
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  vars_str <- paste0('c("', paste(analyze_vars, collapse = '", "'), '")')

  # glue() interpolates {Sys.time()}, {original_filename}, {group_var} and
  # {vars_str}; the doubled braces {{mean}} etc. survive as literal gtsummary
  # statistic placeholders in the emitted script.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 基线特征表 (gtsummary)
# 时间: {Sys.time()}
# ================================

# 自动安装依赖
required_packages <- c("gtsummary", "gt", "dplyr")
new_packages <- required_packages[!(required_packages %in% installed.packages()[,"Package"])]
if(length(new_packages)) install.packages(new_packages, repos = "https://cloud.r-project.org")

library(gtsummary)
library(dplyr)

# 数据准备
df <- read.csv("{original_filename}")
group_var <- "{group_var}"
analyze_vars <- {vars_str}

df_analysis <- df[, c(group_var, analyze_vars)]
df_analysis[[group_var]] <- as.factor(df_analysis[[group_var]])

# 生成基线特征表
tbl <- tbl_summary(
  df_analysis,
  by = all_of(group_var),
  missing = "ifany",
  statistic = list(
    all_continuous() ~ "{{mean}} ({{sd}})",
    all_categorical() ~ "{{n}} ({{p}}%)"
  )
) %>%
  add_p() %>%
  add_overall()

# 显示结果
tbl

# 导出为 Word(可选)
# tbl %>% as_gt() %>% gt::gtsave("baseline_table.docx")
')

  # ----- Return -----
  log_add("分析完成")

  list(
    status = "success",
    message = "基线特征表生成完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = output_results,
    report_blocks = blocks,
    plots = list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}
|
||
|
||
# ===== gtsummary → report_blocks 转换层 =====
|
||
|
||
#' Convert a gtsummary table into report_blocks.
#'
#' Produces three blocks: the three-line baseline table itself, a key/value
#' sample-overview block, and a markdown summary of between-group differences.
#'
#' @param tbl A gtsummary `tbl_summary` object (with `add_p()`/`add_overall()`
#'   already applied); `tbl$inputs$data` supplies the total sample size.
#' @param group_var Name of the grouping column.
#' @param groups Vector of group labels.
#' @param analyze_vars Character vector of analysed variable names.
#' @param significant_vars Variables significant at alpha = 0.05
#'   (from `extract_significant_vars()`).
#' @return A list of report blocks.
gtsummary_to_blocks <- function(tbl, group_var, groups, analyze_vars, significant_vars) {
  blocks <- list()

  # Rendered table as a tibble with raw column names (no display labels).
  tbl_data <- gtsummary::as_tibble(tbl, col_labels = FALSE)

  # Block 1: the three-line baseline table (core output). Cells are
  # stringified, with NA rendered as an empty string.
  headers <- colnames(tbl_data)
  rows <- lapply(seq_len(nrow(tbl_data)), function(i) {
    row <- as.list(tbl_data[i, ])
    lapply(row, function(cell) {
      val <- as.character(cell)
      if (is.na(val) || val == "NA") "" else val
    })
  })

  # Locate the P-value column (present when add_p was applied) for metadata.
  p_col_idx <- which(grepl("p.value|p_value", headers, ignore.case = TRUE))

  blocks[[length(blocks) + 1]] <- make_table_block(
    headers, rows,
    title = glue("基线特征表 (按 {group_var} 分组)"),
    footnote = "连续变量: Mean (SD); 分类变量: N (%); P 值由自动选择的统计方法计算",
    metadata = list(
      is_baseline_table = TRUE,
      group_var = group_var,
      has_p_values = length(p_col_idx) > 0
    )
  )

  # Block 2: sample-size overview.
  # NOTE(review): the original also built a per-group `group_n_items` list
  # here that was never used; removed as dead code.
  blocks[[length(blocks) + 1]] <- make_kv_block(
    list("总样本量" = as.character(nrow(tbl$inputs$data)),
         "分组变量" = group_var,
         "分组数" = as.character(length(groups)),
         "分析变量数" = as.character(length(analyze_vars))),
    title = "样本概况"
  )

  # Block 3: summary of significant between-group differences.
  if (length(significant_vars) > 0) {
    conclusion <- glue("在 α = 0.05 水平下,以下变量在组间存在显著差异:**{paste(significant_vars, collapse = '**、**')}**(共 {length(significant_vars)} 个)。")
  } else {
    conclusion <- "在 α = 0.05 水平下,未发现各组间存在显著差异的基线变量。"
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "组间差异摘要")

  blocks
}
|
||
|
||
#' Extract significant variables from a gtsummary table.
#'
#' @param tbl A gtsummary object whose `table_body` carries `p.value` and
#'   `variable` columns.
#' @param alpha Significance threshold (default 0.05).
#' @return Unique names of variables with P < `alpha`; `character(0)` when
#'   none qualify.
extract_significant_vars <- function(tbl, alpha = 0.05) {
  tbody <- tbl$table_body
  pvals <- as.numeric(unlist(tbody$p.value))
  var_names <- as.character(tbody$variable)

  # Logical mask: rows that carry a P value below the threshold.
  is_sig <- !is.na(pvals) & pvals < alpha
  if (!any(is_sig)) {
    return(character(0))
  }

  unique(var_names[is_sig])
}
|
||
|
||
#' Extract, per tested variable, the statistical method used and its P value.
#'
#' @param tbl A gtsummary object whose `table_body` carries `p.value`,
#'   `variable` and (optionally) `test_name` columns.
#' @return A list of lists, one per row with a non-NA P value, each holding
#'   `variable`, `test_name`, `p_value` (rounded to 4 dp) and `p_value_fmt`
#'   (via `format_p_value()`); an empty list when no row has a P value.
extract_method_info <- function(tbl) {
  body <- tbl$table_body
  p_vals <- as.numeric(unlist(body$p.value))
  has_p <- which(!is.na(p_vals))
  if (length(has_p) == 0) return(list())

  test_names <- if ("test_name" %in% colnames(body)) {
    as.character(unlist(body$test_name))
  } else {
    rep("unknown", nrow(body))
  }

  lapply(has_p, function(i) {
    # BUG FIX: the original used `test_names[i] %||% "unknown"`, but
    # subscripting a character vector yields NA, never NULL, so `%||%`
    # could never substitute the fallback. Check is.na() explicitly.
    tn <- test_names[i]
    list(
      variable = as.character(body$variable[i]),
      test_name = if (is.na(tn)) "unknown" else tn,
      p_value = round(p_vals[i], 4),
      p_value_fmt = format_p_value(p_vals[i])
    )
  })
}
|
||
|
||
#' Null-coalescing operator: return `y` when `x` is NULL, otherwise `x`.
#' (Matches the semantics of base R's `%||%` introduced in R 4.4.)
`%||%` <- function(x, y) {
  if (is.null(x)) {
    y
  } else {
    x
  }
}
|