feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development

Phase I - Session Blackboard + READ Layer:
- SessionBlackboardService with Postgres-Only cache
- DataProfileService for data overview generation
- PicoInferenceService for LLM-driven PICO extraction
- Frontend DataContextCard and VariableDictionaryPanel
- E2E tests: 31/31 passed

Phase II - Conversation Layer LLM + Intent Router:
- ConversationService with SSE streaming
- IntentRouterService (rule-first + LLM fallback, 6 intents)
- SystemPromptService with 6-segment dynamic assembly
- TokenTruncationService for context management
- ChatHandlerService as unified chat entry
- Frontend SSAChatPane and useSSAChat hook
- E2E tests: 38/38 passed

Phase III - Method Consultation + AskUser Standardization:
- ToolRegistryService with Repository Pattern
- MethodConsultService with DecisionTable + LLM enhancement
- AskUserService with global interrupt handling
- Frontend AskUserCard component
- E2E tests: 13/13 passed

Phase IV - Dialogue-Driven Analysis + QPER Integration:
- ToolOrchestratorService (plan/execute/report)
- analysis_plan SSE event for WorkflowPlan transmission
- Dual-channel confirmation (ask_user card + workspace button)
- PICO as optional hint for LLM parsing
- E2E tests: 25/25 passed

R Statistics Service:
- 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon
- Enhanced guardrails and block helpers
- Comprehensive test suite (run_all_tools_test.js)

Documentation:
- Updated system status document (v5.9)
- Updated SSA module status and development plan (v1.8)

Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25)

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2026-02-22 18:53:39 +08:00
parent bf10dec4c8
commit 3446909ff7
68 changed files with 11583 additions and 412 deletions

View File

@@ -0,0 +1,377 @@
#' @tool_code ST_LINEAR_REG
#' @name 线性回归
#' @version 1.0.0
#' @description 连续型结局变量的多因素线性回归分析
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
#' Multiple linear regression (OLS) analysis entry point.
#'
#' @param input A tool-invocation payload. Reads `input$params$outcome_var`
#'   (continuous outcome column name), `input$params$predictors` (character
#'   vector, at least one required), `input$params$confounders` (optional
#'   character vector), and `input$original_filename` (used only in the
#'   generated reproducible code).
#' @return A list with `status`/`message`, structured `results`
#'   (coefficients, model fit, diagnostics, VIF), `report_blocks`, `plots`,
#'   `trace_log` and `reproducible_code`; or an error/blocked payload from
#'   `make_error()` / the guardrail chain.
run_analysis <- function(input) {
  # ===== Initialization =====
  # Timestamped trace messages, returned as trace_log for debugging.
  logs <- c()
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  outcome_var <- p$outcome_var
  predictors <- p$predictors
  confounders <- p$confounders

  # ===== Parameter validation =====
  if (!(outcome_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = outcome_var))
  }
  # c() already drops NULL components; filter NA / empty-string entries
  # element-wise. (A scalar `!is.null(all_vars)` test would never filter
  # individual elements.)
  all_vars <- c(predictors, confounders)
  all_vars <- all_vars[!is.na(all_vars) & all_vars != ""]
  for (v in all_vars) {
    if (!(v %in% names(df))) {
      return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = v))
    }
  }
  if (length(predictors) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ===== Data cleaning =====
  # Listwise deletion over outcome + all model variables.
  original_rows <- nrow(df)
  vars_to_check <- c(outcome_var, all_vars)
  for (v in vars_to_check) {
    df <- df[!is.na(df[[v]]), ]
  }
  # Coerce the outcome to numeric; rows that fail coercion become NA and
  # are dropped as well.
  if (!is.numeric(df[[outcome_var]])) {
    df[[outcome_var]] <- as.numeric(as.character(df[[outcome_var]]))
    df <- df[!is.na(df[[outcome_var]]), ]
  }
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }
  n_predictors <- length(all_vars)

  # ===== Guardrail checks =====
  guardrail_results <- list()
  warnings_list <- c()
  # Rule-of-thumb minimum: ~10 observations beyond the number of predictors.
  sample_check <- check_sample_size(nrow(df), min_required = n_predictors + 10, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("样本量: N = {nrow(df)}, 预测变量数 = {n_predictors}, {sample_check$reason}"))
  guardrail_status <- run_guardrail_chain(guardrail_results)
  if (guardrail_status$status == "blocked") {
    return(list(status = "blocked", message = guardrail_status$reason, trace_log = logs))
  }

  # ===== Build model formula =====
  formula_str <- paste(outcome_var, "~", paste(all_vars, collapse = " + "))
  formula_obj <- as.formula(formula_str)
  log_add(glue("模型公式: {formula_str}"))

  # ===== Core computation =====
  log_add("拟合线性回归模型")
  # withCallingHandlers records and muffles warnings WITHOUT unwinding, so
  # the fitted model is kept even when lm() warns. (A tryCatch warning
  # handler runs after unwinding: the muffleWarning restart is gone and the
  # lm() result would be lost.) tryCatch still traps hard errors.
  model <- tryCatch(
    withCallingHandlers(
      lm(formula_obj, data = df),
      warning = function(w) {
        warnings_list <<- c(warnings_list, w$message)
        log_add(paste("模型警告:", w$message))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("模型拟合失败:", e$message))
      NULL
    }
  )
  if (is.null(model)) {
    return(map_r_error("线性回归模型拟合失败"))
  }
  model_summary <- summary(model)

  # ===== Extract model results =====
  coef_summary <- model_summary$coefficients
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    t_value = coef_summary[, "t value"],
    p_value = coef_summary[, "Pr(>|t|)"],
    stringsAsFactors = FALSE
  )
  # 95% confidence intervals (same row order as the coefficient table).
  ci <- confint(model)
  coef_table$ci_lower <- ci[, 1]
  coef_table$ci_upper <- ci[, 2]
  coefficients_list <- lapply(seq_len(nrow(coef_table)), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      estimate = round(row$estimate, 4),
      std_error = round(row$std_error, 4),
      t_value = round(row$t_value, 3),
      ci_lower = round(row$ci_lower, 4),
      ci_upper = round(row$ci_upper, 4),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      significant = row$p_value < 0.05
    )
  })

  # ===== Model fit statistics =====
  r_squared <- model_summary$r.squared
  adj_r_squared <- model_summary$adj.r.squared
  # fstatistic is NULL for intercept-only-like fits; guard every use.
  f_stat <- model_summary$fstatistic
  f_p_value <- if (!is.null(f_stat)) {
    pf(f_stat[1], f_stat[2], f_stat[3], lower.tail = FALSE)
  } else {
    NA
  }
  if (!is.null(f_stat)) {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = {round(f_stat[1], 2)}, P = {round(f_p_value, 4)}"))
  } else {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = NA"))
  }

  # ===== Residual diagnostics =====
  residuals_vals <- residuals(model)
  fitted_vals <- fitted(model)
  # Residual normality; shapiro.test() only accepts 3..5000 observations.
  normality_p <- NA
  if (length(residuals_vals) >= 3 && length(residuals_vals) <= 5000) {
    normality_test <- shapiro.test(residuals_vals)
    normality_p <- normality_test$p.value
    if (normality_p < 0.05) {
      warnings_list <- c(warnings_list, glue("残差不满足正态性 (Shapiro-Wilk p = {round(normality_p, 4)})"))
    }
  }

  # ===== Collinearity detection (VIF) =====
  # Optional: only when >= 2 predictors and the `car` package is installed.
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        # car::vif returns a matrix (GVIF) when factors are present.
        if (is.matrix(vif_values)) {
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(vif_values[v], 2))
        })
        high_vif <- names(vif_values)[vif_values > 5]
        if (length(high_vif) > 0) {
          warnings_list <- c(warnings_list, paste("VIF > 5 的变量:", paste(high_vif, collapse = ", ")))
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", e$message))
    })
  }

  # ===== Generate plots =====
  log_add("生成诊断图")
  plot_base64 <- tryCatch({
    generate_regression_plots(model, outcome_var)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL
  })

  # ===== Generate reproducible code =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }
  # Template lines kept at column 0 so the emitted script is unindented.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 线性回归
# 时间: {Sys.time()}
# ================================
# 数据准备
df <- read.csv("{original_filename}")
# 线性回归
model <- lm({formula_str}, data = df)
summary(model)
# 置信区间
confint(model)
# 残差诊断
par(mfrow = c(2, 2))
plot(model)
# VIF(需要 car 包)
# library(car)
# vif(model)
')

  # ===== Build report_blocks =====
  blocks <- list()

  # Block 1: model overview
  kv_model <- list(
    "模型公式" = formula_str,
    "观测数" = as.character(nrow(df)),
    "预测变量数" = as.character(n_predictors),
    "R²" = as.character(round(r_squared, 4)),
    "调整 R²" = as.character(round(adj_r_squared, 4))
  )
  if (!is.null(f_stat)) {
    kv_model[["F 统计量"]] <- as.character(round(f_stat[1], 2))
    kv_model[["模型 P 值"]] <- format_p_value(f_p_value)
  }
  if (!is.na(normality_p)) {
    kv_model[["残差正态性 (Shapiro P)"]] <- format_p_value(normality_p)
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_model, title = "模型概况")

  # Block 2: regression coefficient table
  coef_headers <- c("变量", "系数 (B)", "标准误", "t 值", "95% CI", "P 值", "显著性")
  coef_rows <- lapply(coefficients_list, function(row) {
    ci_str <- sprintf("[%.4f, %.4f]", row$ci_lower, row$ci_upper)
    sig <- if (row$significant) "*" else ""
    c(row$variable, as.character(row$estimate), as.character(row$std_error),
      as.character(row$t_value), ci_str, row$p_value_fmt, sig)
  })
  blocks[[length(blocks) + 1]] <- make_table_block(coef_headers, coef_rows,
    title = "回归系数表", footnote = "* P < 0.05")

  # Block 3: VIF table
  if (!is.null(vif_results) && length(vif_results) > 0) {
    vif_headers <- c("变量", "VIF")
    vif_rows <- lapply(vif_results, function(row) c(row$variable, as.character(row$vif)))
    blocks[[length(blocks) + 1]] <- make_table_block(vif_headers, vif_rows, title = "方差膨胀因子 (VIF)")
  }

  # Block 4: diagnostic plots
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64,
      title = "回归诊断图", alt = "残差 vs 拟合值 + Q-Q 图")
  }

  # Block 5: conclusion summary
  # Significant predictors, excluding the intercept.
  sig_vars <- vapply(
    Filter(function(r) r$variable != "(Intercept)" && r$significant, coefficients_list),
    function(r) r$variable,
    character(1)
  )
  model_sig <- if (!is.na(f_p_value) && f_p_value < 0.05) "整体具有统计学意义" else "整体不具有统计学意义"
  f_display <- if (!is.null(f_stat)) round(f_stat[1], 2) else "NA"
  p_display <- if (!is.na(f_p_value)) format_p_value(f_p_value) else "NA"
  conclusion <- glue("线性回归模型{model_sig}(F = {f_display},P {p_display})。模型解释了因变量 {round(r_squared * 100, 1)}% 的变异(R² = {round(r_squared, 4)},调整 R² = {round(adj_r_squared, 4)})。")
  if (length(sig_vars) > 0) {
    conclusion <- paste0(conclusion, glue("\n\n在 α = 0.05 水平下,以下预测变量具有统计学意义:**{paste(sig_vars, collapse = '**、**')}**。"))
  } else {
    conclusion <- paste0(conclusion, "\n\n在 α = 0.05 水平下,无预测变量达到统计学意义。")
  }
  if (length(warnings_list) > 0) {
    # Join multiple warnings with a separator so they do not run together.
    conclusion <- paste0(conclusion, "\n\n⚠ 注意:", paste(warnings_list, collapse = "；"), "。")
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Return result =====
  log_add("分析完成")
  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Multiple Linear Regression (OLS)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        r_squared = jsonlite::unbox(round(r_squared, 4)),
        adj_r_squared = jsonlite::unbox(round(adj_r_squared, 4)),
        f_statistic = if (!is.null(f_stat)) jsonlite::unbox(round(f_stat[1], 2)) else NULL,
        f_df = if (!is.null(f_stat)) as.numeric(f_stat[2:3]) else NULL,
        f_p_value = if (!is.na(f_p_value)) jsonlite::unbox(round(f_p_value, 4)) else NULL,
        f_p_value_fmt = if (!is.na(f_p_value)) format_p_value(f_p_value) else NULL
      ),
      diagnostics = list(
        residual_normality_p = if (!is.na(normality_p)) jsonlite::unbox(round(normality_p, 4)) else NULL
      ),
      vif = vif_results
    ),
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
#' Helper: regression diagnostic plots (residuals vs fitted + Q-Q, side by side).
#'
#' @param model A fitted `lm` object.
#' @param outcome_var Outcome variable name. Currently unused; kept for
#'   interface compatibility with the caller.
#' @return A `data:image/png;base64,` URI string containing the combined
#'   diagnostic plot (both panels when `gridExtra` is installed, otherwise
#'   only the residuals-vs-fitted panel).
generate_regression_plots <- function(model, outcome_var) {
  diag_df <- data.frame(
    fitted = fitted(model),
    residuals = residuals(model),
    std_residuals = rstandard(model)
  )

  # Residuals vs fitted values, with a loess trend to reveal non-linearity.
  p1 <- ggplot(diag_df, aes(x = fitted, y = residuals)) +
    geom_point(alpha = 0.5, color = "#3b82f6") +
    geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
    geom_smooth(method = "loess", se = FALSE, color = "orange", linewidth = 0.8) +
    theme_minimal() +
    labs(title = "Residuals vs Fitted", x = "Fitted values", y = "Residuals")

  # Normal Q-Q plot of standardized residuals.
  p2 <- ggplot(diag_df, aes(sample = std_residuals)) +
    stat_qq(alpha = 0.5, color = "#3b82f6") +
    stat_qq_line(color = "red", linetype = "dashed") +
    theme_minimal() +
    labs(title = "Normal Q-Q Plot", x = "Theoretical Quantiles", y = "Standardized Residuals")

  # Render to a temp PNG; on.exit guarantees cleanup even if ggsave() or
  # base64encode() throws (previously the file leaked on error).
  tmp_file <- tempfile(fileext = ".png")
  on.exit(unlink(tmp_file), add = TRUE)
  if (requireNamespace("gridExtra", quietly = TRUE)) {
    combined <- gridExtra::arrangeGrob(p1, p2, ncol = 2)
    ggsave(tmp_file, combined, width = 12, height = 5, dpi = 100)
  } else {
    # Fallback: single panel when gridExtra is unavailable.
    ggsave(tmp_file, p1, width = 7, height = 5, dpi = 100)
  }
  base64_str <- base64encode(tmp_file)
  return(paste0("data:image/png;base64,", base64_str))
}