Files
HaHafeng 3446909ff7 feat(ssa): Complete Phase I-IV intelligent dialogue and tool system development
Phase I - Session Blackboard + READ Layer:
- SessionBlackboardService with Postgres-Only cache
- DataProfileService for data overview generation
- PicoInferenceService for LLM-driven PICO extraction
- Frontend DataContextCard and VariableDictionaryPanel
- E2E tests: 31/31 passed

Phase II - Conversation Layer LLM + Intent Router:
- ConversationService with SSE streaming
- IntentRouterService (rule-first + LLM fallback, 6 intents)
- SystemPromptService with 6-segment dynamic assembly
- TokenTruncationService for context management
- ChatHandlerService as unified chat entry
- Frontend SSAChatPane and useSSAChat hook
- E2E tests: 38/38 passed

Phase III - Method Consultation + AskUser Standardization:
- ToolRegistryService with Repository Pattern
- MethodConsultService with DecisionTable + LLM enhancement
- AskUserService with global interrupt handling
- Frontend AskUserCard component
- E2E tests: 13/13 passed

Phase IV - Dialogue-Driven Analysis + QPER Integration:
- ToolOrchestratorService (plan/execute/report)
- analysis_plan SSE event for WorkflowPlan transmission
- Dual-channel confirmation (ask_user card + workspace button)
- PICO as optional hint for LLM parsing
- E2E tests: 25/25 passed

R Statistics Service:
- 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon
- Enhanced guardrails and block helpers
- Comprehensive test suite (run_all_tools_test.js)

Documentation:
- Updated system status document (v5.9)
- Updated SSA module status and development plan (v1.8)

Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-22 18:53:39 +08:00

378 lines
12 KiB
R
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_LINEAR_REG
#' @name 线性回归
#' @version 1.0.0
#' @description 连续型结局变量的多因素线性回归分析
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
# Run a multiple linear regression (OLS) for a continuous outcome.
#
# Steps: load data -> validate variables -> drop incomplete rows ->
# sample-size guardrail -> fit lm() -> collect coefficients, fit statistics,
# residual normality (Shapiro-Wilk), VIF (when the car package is installed),
# a diagnostic plot, reproducible code and report blocks.
#
# Parameters:
#   input - list. `input$params` supplies `outcome_var` (one column name),
#           `predictors` and `confounders` (column-name vectors, may be NULL).
#           `input$original_filename` is optional and only used in the
#           generated reproducible code. Other fields are consumed by
#           load_input_data() (schema not visible here).
# Returns:
#   A list with status = "success" and the result payload; a list with
#   status = "blocked" when a guardrail blocks; or the value of make_error()
#   / map_r_error() on failure.
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- c()
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # Normalise a user-supplied variable list: flatten, coerce to character and
  # drop NULL / NA / empty entries.
  clean_vars <- function(x) {
    x <- as.character(unlist(x, use.names = FALSE))
    x[!is.na(x) & nzchar(x)]
  }

  # Backquote non-syntactic column names (spaces, brackets, leading digits,
  # reserved words) so each stays a single term when the formula is parsed.
  # Syntactic names are left as-is, so ordinary formulas are unchanged.
  quote_term <- function(v) {
    if (identical(make.names(v), v)) {
      v
    } else {
      paste0("`", gsub("`", "\\\\`", v), "`")
    }
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  outcome_var <- clean_vars(p$outcome_var)
  predictors <- clean_vars(p$predictors)
  confounders <- clean_vars(p$confounders)

  # ===== Parameter validation =====
  # A missing/empty outcome would otherwise make `if ()` see a zero-length
  # condition and error instead of returning a structured error.
  outcome_ok <- length(outcome_var) == 1 && outcome_var %in% names(df)
  if (!outcome_ok) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND,
                      col = paste(outcome_var, collapse = ", ")))
  }

  # Terms shared between predictors and confounders are counted once, and
  # the outcome can never also appear on the right-hand side.
  all_vars <- setdiff(unique(c(predictors, confounders)), outcome_var)
  missing_vars <- setdiff(all_vars, names(df))
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = missing_vars[1]))
  }
  if (length(setdiff(predictors, outcome_var)) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ===== Data cleaning =====
  original_rows <- nrow(df)
  vars_to_check <- c(outcome_var, all_vars)
  # Keep only rows with no missing value in any model variable.
  df <- df[complete.cases(df[, vars_to_check, drop = FALSE]), , drop = FALSE]

  # The outcome must be numeric. Factors/characters are converted via their
  # labels (values that fail to convert become NA and are dropped); logicals
  # are converted directly, because as.character() would turn them into
  # "TRUE"/"FALSE" and then NA.
  y <- df[[outcome_var]]
  if (!is.numeric(y)) {
    df[[outcome_var]] <- if (is.logical(y)) as.numeric(y) else as.numeric(as.character(y))
    df <- df[!is.na(df[[outcome_var]]), , drop = FALSE]
  }
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }
  n_predictors <- length(all_vars)

  # ===== Guardrails =====
  guardrail_results <- list()
  warnings_list <- c()
  sample_check <- check_sample_size(nrow(df), min_required = n_predictors + 10, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("样本量: N = {nrow(df)}, 预测变量数 = {n_predictors}, {sample_check$reason}"))
  guardrail_status <- run_guardrail_chain(guardrail_results)
  if (guardrail_status$status == "blocked") {
    return(list(status = "blocked", message = guardrail_status$reason, trace_log = logs))
  }

  # ===== Model formula =====
  rhs_terms <- vapply(all_vars, quote_term, character(1), USE.NAMES = FALSE)
  formula_str <- paste(quote_term(outcome_var), "~", paste(rhs_terms, collapse = " + "))
  formula_obj <- as.formula(formula_str)
  log_add(glue("模型公式: {formula_str}"))

  # ===== Model fit =====
  log_add("拟合线性回归模型")
  # Warnings are recorded via a *calling* handler and muffled, so fitting
  # continues and the lm object is kept. A warning handler on tryCatch
  # would unwind the call: the model would be lost, and
  # invokeRestart("muffleWarning") would itself fail because no restart
  # exists there.
  model <- tryCatch(
    withCallingHandlers(
      lm(formula_obj, data = df),
      warning = function(w) {
        warnings_list <<- c(warnings_list, conditionMessage(w))
        log_add(paste("模型警告:", conditionMessage(w)))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("模型拟合失败:", e$message))
      NULL
    }
  )
  if (is.null(model)) {
    return(map_r_error("线性回归模型拟合失败"))
  }
  model_summary <- summary(model)

  # ===== Coefficients =====
  coef_summary <- model_summary$coefficients
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    t_value = coef_summary[, "t value"],
    p_value = coef_summary[, "Pr(>|t|)"],
    stringsAsFactors = FALSE
  )
  # 95% confidence intervals. confint() keeps (all-NA) rows for aliased
  # coefficients that summary() drops, so select rows by name to keep both
  # tables aligned.
  ci <- confint(model)[rownames(coef_summary), , drop = FALSE]
  coef_table$ci_lower <- ci[, 1]
  coef_table$ci_upper <- ci[, 2]
  coefficients_list <- lapply(seq_len(nrow(coef_table)), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      estimate = round(row$estimate, 4),
      std_error = round(row$std_error, 4),
      t_value = round(row$t_value, 3),
      ci_lower = round(row$ci_lower, 4),
      ci_upper = round(row$ci_upper, 4),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      # A NaN p-value (e.g. zero residual df) counts as not significant
      # rather than propagating NA into later `if` conditions.
      significant = isTRUE(row$p_value < 0.05)
    )
  })

  # ===== Goodness of fit =====
  r_squared <- model_summary$r.squared
  adj_r_squared <- model_summary$adj.r.squared
  f_stat <- model_summary$fstatistic  # NULL for intercept-only models
  f_p_value <- if (!is.null(f_stat)) {
    pf(f_stat[1], f_stat[2], f_stat[3], lower.tail = FALSE)
  } else {
    NA
  }
  if (!is.null(f_stat)) {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = {round(f_stat[1], 2)}, P = {round(f_p_value, 4)}"))
  } else {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = NA"))
  }

  # ===== Residual diagnostics =====
  residuals_vals <- residuals(model)
  # Shapiro-Wilk accepts 3..5000 observations; it can still fail (e.g. all
  # residuals identical), in which case the test is simply skipped.
  normality_p <- NA
  n_resid <- length(residuals_vals)
  if (n_resid >= 3 && n_resid <= 5000) {
    normality_p <- tryCatch(
      shapiro.test(residuals_vals)$p.value,
      error = function(e) {
        log_add(paste("正态性检验失败:", e$message))
        NA
      }
    )
    if (!is.na(normality_p) && normality_p < 0.05) {
      warnings_list <- c(warnings_list, glue("残差不满足正态性 (Shapiro-Wilk p = {round(normality_p, 4)})"))
    }
  }

  # ===== Collinearity (VIF) =====
  # Only attempted with more than one term and when car is installed; any
  # failure (e.g. aliased coefficients) is logged and VIF is omitted.
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        # For factor terms car::vif() returns a matrix; use the GVIF column.
        if (is.matrix(vif_values)) {
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(unname(vif_values[v]), 2))
        })
        high_vif <- names(vif_values)[vif_values > 5]
        if (length(high_vif) > 0) {
          warnings_list <- c(warnings_list, paste("VIF > 5 的变量:", paste(high_vif, collapse = ", ")))
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", e$message))
    })
  }

  # ===== Plots =====
  log_add("生成诊断图")
  plot_base64 <- tryCatch({
    generate_regression_plots(model, outcome_var)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL
  })

  # ===== Reproducible code =====
  fname <- input$original_filename
  original_filename <- if (is.character(fname) && length(fname) == 1 &&
                           !is.na(fname) && nzchar(fname)) {
    fname
  } else {
    "data.csv"
  }
  # Escape quotes/backslashes so the generated read.csv() call is valid R.
  filename_literal <- encodeString(original_filename, quote = '"')
  reproducible_code <- glue('
    # SSA-Pro 自动生成代码
    # 工具: 线性回归
    # 时间: {Sys.time()}
    # ================================
    # 数据准备
    df <- read.csv({filename_literal})
    # 线性回归
    model <- lm({formula_str}, data = df)
    summary(model)
    # 置信区间
    confint(model)
    # 残差诊断
    par(mfrow = c(2, 2))
    plot(model)
    # VIF (需要 car 包)
    # library(car)
    # vif(model)
    ')

  # ===== report_blocks =====
  blocks <- list()

  # Block 1: model overview
  kv_model <- list(
    "模型公式" = formula_str,
    "观测数" = as.character(nrow(df)),
    "预测变量数" = as.character(n_predictors),
    "R²" = as.character(round(r_squared, 4)),
    "调整 R²" = as.character(round(adj_r_squared, 4))
  )
  if (!is.null(f_stat)) {
    kv_model[["F 统计量"]] <- as.character(round(f_stat[1], 2))
    kv_model[["模型 P 值"]] <- format_p_value(f_p_value)
  }
  if (!is.na(normality_p)) {
    kv_model[["残差正态性 (Shapiro P)"]] <- format_p_value(normality_p)
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_model, title = "模型概况")

  # Block 2: coefficient table
  coef_headers <- c("变量", "系数 (B)", "标准误", "t 值", "95% CI", "P 值", "显著性")
  coef_rows <- lapply(coefficients_list, function(row) {
    ci_str <- sprintf("[%.4f, %.4f]", row$ci_lower, row$ci_upper)
    sig <- if (row$significant) "*" else ""
    c(row$variable, as.character(row$estimate), as.character(row$std_error),
      as.character(row$t_value), ci_str, row$p_value_fmt, sig)
  })
  blocks[[length(blocks) + 1]] <- make_table_block(coef_headers, coef_rows,
                                                   title = "回归系数表", footnote = "* P < 0.05")

  # Block 3: VIF table (only when VIF was computed)
  if (!is.null(vif_results) && length(vif_results) > 0) {
    vif_headers <- c("变量", "VIF")
    vif_rows <- lapply(vif_results, function(row) c(row$variable, as.character(row$vif)))
    blocks[[length(blocks) + 1]] <- make_table_block(vif_headers, vif_rows, title = "方差膨胀因子 (VIF)")
  }

  # Block 4: diagnostic plot
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(plot_base64,
                                                     title = "回归诊断图", alt = "残差 vs 拟合值 + Q-Q 图")
  }

  # Block 5: conclusion summary
  is_sig_term <- vapply(coefficients_list, function(r) {
    r$variable != "(Intercept)" && r$significant
  }, logical(1))
  sig_vars <- vapply(coefficients_list[is_sig_term], function(r) r$variable, character(1))
  model_sig <- if (!is.na(f_p_value) && f_p_value < 0.05) "整体具有统计学意义" else "整体不具有统计学意义"
  f_display <- if (!is.null(f_stat)) round(f_stat[1], 2) else "NA"
  p_display <- if (!is.na(f_p_value)) format_p_value(f_p_value) else "NA"
  conclusion <- glue(
    "线性回归模型{model_sig}（F = {f_display}，P {p_display}）。",
    "模型解释了因变量 {round(r_squared * 100, 1)}% 的变异",
    "（R² = {round(r_squared, 4)}，调整 R² = {round(adj_r_squared, 4)}）。"
  )
  if (length(sig_vars) > 0) {
    conclusion <- paste0(conclusion, glue("\n\n在 α = 0.05 水平下,以下预测变量具有统计学意义:**{paste(sig_vars, collapse = '**、**')}**。"))
  } else {
    conclusion <- paste0(conclusion, "\n\n在 α = 0.05 水平下,无预测变量达到统计学意义。")
  }
  if (length(warnings_list) > 0) {
    # Separate individual warnings instead of running them together.
    conclusion <- paste0(conclusion, "\n\n⚠ 注意:", paste(warnings_list, collapse = "；"), "。")
  }
  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ===== Result =====
  log_add("分析完成")
  list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Multiple Linear Regression (OLS)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        r_squared = jsonlite::unbox(round(r_squared, 4)),
        adj_r_squared = jsonlite::unbox(round(adj_r_squared, 4)),
        f_statistic = if (!is.null(f_stat)) jsonlite::unbox(round(f_stat[1], 2)) else NULL,
        f_df = if (!is.null(f_stat)) as.numeric(f_stat[2:3]) else NULL,
        f_p_value = if (!is.na(f_p_value)) jsonlite::unbox(round(f_p_value, 4)) else NULL,
        f_p_value_fmt = if (!is.na(f_p_value)) format_p_value(f_p_value) else NULL
      ),
      diagnostics = list(
        residual_normality_p = if (!is.na(normality_p)) jsonlite::unbox(round(normality_p, 4)) else NULL
      ),
      vif = vif_results
    ),
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}
# Helper: regression diagnostic plot, returned as a base64 PNG data URI.
#
# Draws "Residuals vs Fitted" (with a loess trend line) and a normal Q-Q
# plot of standardized residuals. The two panels are placed side by side
# when gridExtra is installed; otherwise only the residuals-vs-fitted panel
# is saved.
#
# Parameters:
#   model       - a fitted lm object.
#   outcome_var - outcome column name; currently unused, kept so the
#                 caller's signature stays unchanged.
# Returns:
#   A string of the form "data:image/png;base64,<...>".
generate_regression_plots <- function(model, outcome_var) {
  diag_df <- data.frame(
    fitted = fitted(model),
    residuals = residuals(model),
    std_residuals = rstandard(model)
  )

  # Residuals vs fitted values
  p1 <- ggplot(diag_df, aes(x = fitted, y = residuals)) +
    geom_point(alpha = 0.5, color = "#3b82f6") +
    geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
    # Explicit formula avoids ggplot2's "using formula = 'y ~ x'" message.
    geom_smooth(method = "loess", formula = y ~ x, se = FALSE,
                color = "orange", linewidth = 0.8) +
    theme_minimal() +
    labs(title = "Residuals vs Fitted", x = "Fitted values", y = "Residuals")

  # Normal Q-Q plot of standardized residuals
  p2 <- ggplot(diag_df, aes(sample = std_residuals)) +
    stat_qq(alpha = 0.5, color = "#3b82f6") +
    stat_qq_line(color = "red", linetype = "dashed") +
    theme_minimal() +
    labs(title = "Normal Q-Q Plot", x = "Theoretical Quantiles", y = "Standardized Residuals")

  tmp_file <- tempfile(fileext = ".png")
  # Remove the temporary PNG even if ggsave() or base64encode() fails.
  on.exit(unlink(tmp_file), add = TRUE)

  if (requireNamespace("gridExtra", quietly = TRUE)) {
    combined <- gridExtra::arrangeGrob(p1, p2, ncol = 2)
    ggsave(tmp_file, combined, width = 12, height = 5, dpi = 100)
  } else {
    ggsave(tmp_file, p1, width = 7, height = 5, dpi = 100)
  }

  paste0("data:image/png;base64,", base64encode(tmp_file))
}