Files
HaHafeng 3446909ff7 feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development
Phase I - Session Blackboard + READ Layer:
- SessionBlackboardService with Postgres-Only cache
- DataProfileService for data overview generation
- PicoInferenceService for LLM-driven PICO extraction
- Frontend DataContextCard and VariableDictionaryPanel
- E2E tests: 31/31 passed

Phase II - Conversation Layer LLM + Intent Router:
- ConversationService with SSE streaming
- IntentRouterService (rule-first + LLM fallback, 6 intents)
- SystemPromptService with 6-segment dynamic assembly
- TokenTruncationService for context management
- ChatHandlerService as unified chat entry
- Frontend SSAChatPane and useSSAChat hook
- E2E tests: 38/38 passed

Phase III - Method Consultation + AskUser Standardization:
- ToolRegistryService with Repository Pattern
- MethodConsultService with DecisionTable + LLM enhancement
- AskUserService with global interrupt handling
- Frontend AskUserCard component
- E2E tests: 13/13 passed

Phase IV - Dialogue-Driven Analysis + QPER Integration:
- ToolOrchestratorService (plan/execute/report)
- analysis_plan SSE event for WorkflowPlan transmission
- Dual-channel confirmation (ask_user card + workspace button)
- PICO as optional hint for LLM parsing
- E2E tests: 25/25 passed

R Statistics Service:
- 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon
- Enhanced guardrails and block helpers
- Comprehensive test suite (run_all_tools_test.js)

Documentation:
- Updated system status document (v5.9)
- Updated SSA module status and development plan (v1.8)

Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-22 18:53:39 +08:00

317 lines
9.5 KiB
R
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_BASELINE_TABLE
#' @name 基线特征表(复合工具)
#' @version 1.0.0
#' @description 基于 gtsummary 的一键式基线特征表生成,自动判断变量类型、选择统计方法、输出标准三线表
#' @author SSA-Pro Team
#' @note 复合工具:一次遍历所有变量,自动选择方法(T/Wilcoxon/χ²/Fisher),合并出表
library(glue)
library(ggplot2)
library(base64enc)
#' Composite tool entry point: one-shot baseline characteristics table.
#'
#' Loads the input dataset, validates the grouping/analysis variables,
#' builds a gtsummary three-line table (per-variable tests auto-selected by
#' add_p()), converts it to report_blocks, and returns a structured result.
#'
#' @param input Request object. Fields read here: `params$group_var`,
#'   `params$analyze_vars`, `original_filename`, plus whatever the
#'   service-level helper `load_input_data()` consumes.
#' @return On success, a list with `status`, `message`, `warnings`,
#'   `results`, `report_blocks`, `plots`, `trace_log` and
#'   `reproducible_code`. On failure, an error object from `make_error()` /
#'   `map_r_error()`, or a `status = "blocked"` list when the sample-size
#'   guardrail trips.
run_analysis <- function(input) {
# ===== Initialisation =====
logs <- c()
# Appends a timestamped entry to the closure-level `logs` via `<<-`.
log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }
warnings_list <- c()
# NOTE(review): empty on.exit() handler is a no-op — presumably a
# placeholder kept for symmetry with other tools; confirm it can go.
on.exit({}, add = TRUE)
# ===== Dependency check =====
# Fail fast with a structured E101 error if any required package is absent.
required_pkgs <- c("gtsummary", "gt", "broom")
for (pkg in required_pkgs) {
if (!requireNamespace(pkg, quietly = TRUE)) {
return(make_error(ERROR_CODES$E101_PACKAGE_MISSING, package = pkg))
}
}
library(gtsummary)
library(dplyr)
# ===== Data loading =====
log_add("开始加载输入数据")
# load_input_data() is a service-level helper; failures are logged and
# converted to NULL so we can return a structured error below.
df <- tryCatch(
load_input_data(input),
error = function(e) {
log_add(paste("数据加载失败:", e$message))
return(NULL)
}
)
if (is.null(df)) {
return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
}
log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))
p <- input$params
group_var <- p$group_var
analyze_vars <- as.character(unlist(p$analyze_vars))
# ===== Parameter validation =====
if (is.null(group_var) || !(group_var %in% names(df))) {
return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = group_var %||% "NULL"))
}
# When no analysis variables are given, default to every non-group column.
if (is.null(analyze_vars) || length(analyze_vars) == 0) {
analyze_vars <- setdiff(names(df), group_var)
log_add(glue("未指定分析变量,自动选取全部 {length(analyze_vars)} 个变量"))
}
missing_vars <- analyze_vars[!(analyze_vars %in% names(df))]
if (length(missing_vars) > 0) {
return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
col = paste(missing_vars, collapse = ", ")))
}
# ===== Data cleaning =====
# Drop rows whose grouping value is NA or blank/whitespace-only.
original_rows <- nrow(df)
df <- df[!is.na(df[[group_var]]) & trimws(as.character(df[[group_var]])) != "", ]
removed_rows <- original_rows - nrow(df)
if (removed_rows > 0) {
log_add(glue("分组变量缺失值清洗: 移除 {removed_rows} 行 (剩余 {nrow(df)} 行)"))
}
groups <- unique(df[[group_var]])
n_groups <- length(groups)
# Between-group comparison needs at least two non-empty groups.
if (n_groups < 2) {
return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
col = group_var, expected = "2+", actual = n_groups))
}
# Sample size guardrail (blocks the run below 10 rows after cleaning).
sample_check <- check_sample_size(nrow(df), min_required = 10, action = ACTION_BLOCK)
if (!sample_check$passed) {
return(list(status = "blocked", message = sample_check$reason, trace_log = logs))
}
# Ensure the grouping variable is a factor (required by tbl_summary's `by`).
df[[group_var]] <- as.factor(df[[group_var]])
# Restrict to the grouping column plus the analysis columns.
df_analysis <- df[, c(group_var, analyze_vars), drop = FALSE]
log_add(glue("分组变量: {group_var} ({n_groups} 组: {paste(groups, collapse=', ')})"))
log_add(glue("分析变量: {length(analyze_vars)} 个"))
# ===== Core computation (gtsummary) =====
log_add("使用 gtsummary 生成基线特征表")
# Warnings are captured into `warnings_list` (and muffled) via
# withCallingHandlers; hard errors are converted to NULL by tryCatch.
tbl <- tryCatch(
withCallingHandlers(
{
tbl_summary(
df_analysis,
by = all_of(group_var),
missing = "ifany",
statistic = list(
all_continuous() ~ "{mean} ({sd})",
all_categorical() ~ "{n} ({p}%)"
),
digits = list(
all_continuous() ~ 2,
all_categorical() ~ c(0, 1)
)
) %>%
add_p() %>%
add_overall()
},
warning = function(w) {
warnings_list <<- c(warnings_list, w$message)
log_add(paste("gtsummary 警告:", w$message))
invokeRestart("muffleWarning")
}
),
error = function(e) {
log_add(paste("gtsummary 生成失败:", e$message))
return(NULL)
}
)
if (is.null(tbl)) {
return(map_r_error("gtsummary 基线特征表生成失败"))
}
log_add("gtsummary 表格生成成功")
# ===== Extract structured data =====
# NOTE(review): `tbl_df` is not referenced later in this function —
# possibly kept for debugging; confirm whether it can be dropped.
tbl_df <- as.data.frame(tbl$table_body)
# Variables significant at the 0.05 level (helper defined in this file).
significant_vars <- extract_significant_vars(tbl, alpha = 0.05)
log_add(glue("显著变量 (P < 0.05): {length(significant_vars)} 个"))
# Which statistical test was applied to each variable.
method_info <- extract_method_info(tbl)
# ===== Convert to report_blocks =====
log_add("转换 gtsummary → report_blocks")
blocks <- gtsummary_to_blocks(tbl, group_var, groups, analyze_vars, significant_vars)
# ===== Build structured results =====
output_results <- list(
method = "gtsummary::tbl_summary + add_p",
group_var = group_var,
n_groups = n_groups,
groups = lapply(groups, function(g) {
list(label = as.character(g), n = sum(df[[group_var]] == g))
}),
n_variables = length(analyze_vars),
significant_vars = significant_vars,
method_info = method_info,
total_n = nrow(df)
)
# ===== Generate reproducible code =====
original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
input$original_filename
} else {
"data.csv"
}
vars_str <- paste0('c("', paste(analyze_vars, collapse = '", "'), '")')
# The template below is a runtime string: `{x}` interpolates now, while
# `{{...}}` escapes braces so gtsummary sees its own `{mean}`-style tokens.
reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 基线特征表 (gtsummary)
# 时间: {Sys.time()}
# ================================
# 自动安装依赖
required_packages <- c("gtsummary", "gt", "dplyr")
new_packages <- required_packages[!(required_packages %in% installed.packages()[,"Package"])]
if(length(new_packages)) install.packages(new_packages, repos = "https://cloud.r-project.org")
library(gtsummary)
library(dplyr)
# 数据准备
df <- read.csv("{original_filename}")
group_var <- "{group_var}"
analyze_vars <- {vars_str}
df_analysis <- df[, c(group_var, analyze_vars)]
df_analysis[[group_var]] <- as.factor(df_analysis[[group_var]])
# 生成基线特征表
tbl <- tbl_summary(
df_analysis,
by = all_of(group_var),
missing = "ifany",
statistic = list(
all_continuous() ~ "{{mean}} ({{sd}})",
all_categorical() ~ "{{n}} ({{p}}%)"
)
) %>%
add_p() %>%
add_overall()
# 显示结果
tbl
# 导出为 Word可选
# tbl %>% as_gt() %>% gt::gtsave("baseline_table.docx")
')
# ===== Return result =====
log_add("分析完成")
return(list(
status = "success",
message = "基线特征表生成完成",
warnings = if (length(warnings_list) > 0) warnings_list else NULL,
results = output_results,
report_blocks = blocks,
plots = list(),
trace_log = logs,
reproducible_code = as.character(reproducible_code)
))
}
# ===== gtsummary → report_blocks conversion layer =====
#' Convert a gtsummary table into the service's report_blocks format.
#'
#' Emits three blocks: (1) the baseline three-line table itself, (2) a
#' key/value sample-overview block, and (3) a markdown summary of the
#' between-group differences.
#'
#' @param tbl A gtsummary object (tbl_summary + add_p/add_overall); both the
#'   rendered tibble and `tbl$inputs$data` are read.
#' @param group_var Name of the grouping column (string).
#' @param groups Vector of group levels (used for the group count).
#' @param analyze_vars Character vector of analysed variable names.
#' @param significant_vars Character vector of variables with P < 0.05.
#' @return A list of report blocks built via the service-level
#'   `make_table_block()` / `make_kv_block()` / `make_markdown_block()`.
gtsummary_to_blocks <- function(tbl, group_var, groups, analyze_vars, significant_vars) {
  blocks <- list()
  # Flatten to a tibble with machine column names (col_labels = FALSE).
  tbl_data <- gtsummary::as_tibble(tbl, col_labels = FALSE)
  # Block 1: the three-line table (core output). Cells are stringified and
  # NA rendered as the empty string.
  headers <- colnames(tbl_data)
  rows <- lapply(seq_len(nrow(tbl_data)), function(i) {
    row <- as.list(tbl_data[i, ])
    lapply(row, function(cell) {
      val <- as.character(cell)
      if (is.na(val) || val == "NA") "" else val
    })
  })
  # Detect a P-value column so the frontend can flag significant rows.
  p_col_idx <- which(grepl("p.value|p_value", headers, ignore.case = TRUE))
  blocks[[length(blocks) + 1]] <- make_table_block(
    headers, rows,
    title = glue("基线特征表 (按 {group_var} 分组)"),
    footnote = "连续变量: Mean (SD); 分类变量: N (%); P 值由自动选择的统计方法计算",
    metadata = list(
      is_baseline_table = TRUE,
      group_var = group_var,
      has_p_values = length(p_col_idx) > 0
    )
  )
  # Block 2: sample overview. (A per-group placeholder list that was
  # computed here but never referenced has been removed.)
  blocks[[length(blocks) + 1]] <- make_kv_block(
    list("总样本量" = as.character(nrow(tbl$inputs$data)),
         "分组变量" = group_var,
         "分组数" = as.character(length(groups)),
         "分析变量数" = as.character(length(analyze_vars))),
    title = "样本概况"
  )
  # Block 3: significant-variable summary in markdown.
  if (length(significant_vars) > 0) {
    conclusion <- glue("α = 0.05 水平下,以下变量在组间存在显著差异:**{paste(significant_vars, collapse = '**、**')}**(共 {length(significant_vars)} 个)。")
  } else {
    conclusion <- "α = 0.05 水平下,未发现各组间存在显著差异的基线变量。"
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "组间差异摘要")
  return(blocks)
}
#' Extract significant variables from a gtsummary table.
#'
#' @param tbl A gtsummary object; reads `tbl$table_body$p.value` and
#'   `tbl$table_body$variable`.
#' @param alpha Significance threshold (default 0.05).
#' @return Unique names of variables with a non-missing P-value below
#'   `alpha`; `character(0)` when none qualify.
extract_significant_vars <- function(tbl, alpha = 0.05) {
  tab <- tbl$table_body
  pv <- as.numeric(unlist(tab$p.value))
  var_names <- as.character(tab$variable)
  is_sig <- !is.na(pv) & pv < alpha
  if (!any(is_sig)) {
    return(character(0))
  }
  unique(var_names[is_sig])
}
#' Extract the statistical method applied to each variable.
#'
#' @param tbl A gtsummary object; reads `table_body$p.value`,
#'   `table_body$variable` and (when present) `table_body$test_name`.
#' @return A list with one entry per row carrying a P-value; each entry has
#'   `variable`, `test_name` ("unknown" when unavailable), `p_value`
#'   (rounded to 4 dp) and `p_value_fmt` (via the service-level
#'   `format_p_value()` helper). Empty list when no row has a P-value.
extract_method_info <- function(tbl) {
  body <- tbl$table_body
  p_vals <- as.numeric(unlist(body$p.value))
  has_p <- which(!is.na(p_vals))
  if (length(has_p) == 0) return(list())
  test_names <- if ("test_name" %in% colnames(body)) {
    as.character(unlist(body$test_name))
  } else {
    rep("unknown", nrow(body))
  }
  lapply(has_p, function(i) {
    # Subscripting an atomic vector never yields NULL — a missing test name
    # comes through as NA, so `%||%` could not supply the fallback here.
    # An explicit NA check is required for "unknown" to ever apply.
    tn <- test_names[i]
    list(
      variable = as.character(body$variable[i]),
      test_name = if (is.na(tn)) "unknown" else tn,
      p_value = round(p_vals[i], 4),
      p_value_fmt = format_p_value(p_vals[i])
    )
  })
}
# Null-coalescing operator: return `x` unless it is NULL, in which case
# return `y`. Note that NA and zero-length vectors are NOT treated as
# missing — only NULL triggers the fallback.
`%||%` <- function(x, y) {
  if (!is.null(x)) {
    x
  } else {
    y
  }
}