#' @tool_code ST_LINEAR_REG
#' @name 线性回归
#' @version 1.0.0
#' @description 连续型结局变量的多因素线性回归分析
#' @author SSA-Pro Team

# Tool: multiple linear regression (OLS) for a continuous outcome variable.

library(glue)
library(ggplot2)
library(base64enc)

#' Run a multiple linear regression analysis.
#'
#' Loads the input data, validates the requested columns, drops rows with a
#' missing value in any model variable, coerces the outcome to numeric,
#' applies the sample-size guardrail, fits
#' `lm(outcome ~ predictors + confounders)` and assembles coefficients with
#' 95% confidence intervals, fit statistics, a Shapiro-Wilk residual test,
#' VIFs (only when the `car` package is installed), a diagnostic plot,
#' reproducible R code and report blocks.
#'
#' @param input List with `params$outcome_var`, `params$predictors`,
#'   `params$confounders`, an optional `original_filename`, plus whatever
#'   `load_input_data()` needs to read the data set.
#' @return On success a list with `status = "success"`, `results`,
#'   `report_blocks`, `plots`, `trace_log` and `reproducible_code`; a list
#'   with `status = "blocked"` and `trace_log` when the guardrail blocks;
#'   otherwise the value returned by `make_error()` / `map_r_error()`.
run_analysis <- function(input) {
  # ---- Initialisation ----
  logs <- character(0)
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ---- Data loading ----
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", e$message))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  # ---- Parameters ----
  # Normalise the variable lists (they may arrive as JSON lists) and keep the
  # outcome off the right-hand side / confounders disjoint from predictors,
  # so the formula never repeats a term.
  p <- input$params
  outcome_var <- .linreg_clean_vars(p$outcome_var)
  predictors <- setdiff(.linreg_clean_vars(p$predictors), outcome_var)
  confounders <- setdiff(
    .linreg_clean_vars(p$confounders),
    c(outcome_var, predictors)
  )
  all_vars <- c(predictors, confounders)

  # ---- Parameter validation ----
  # `length(...) != 1L` is checked first so a NULL/empty outcome cannot reach
  # `if (logical(0))`.
  if (length(outcome_var) != 1L || !(outcome_var %in% names(df))) {
    return(make_error(
      ERROR_CODES$E001_COLUMN_NOT_FOUND,
      col = paste(outcome_var, collapse = ", ")
    ))
  }
  missing_vars <- setdiff(all_vars, names(df))
  if (length(missing_vars) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = missing_vars[[1]]))
  }
  if (length(predictors) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ---- Data cleaning ----
  original_rows <- nrow(df)
  vars_to_check <- c(outcome_var, all_vars)
  is_complete <- Reduce(`&`, lapply(vars_to_check, function(v) !is.na(df[[v]])))
  # drop = FALSE keeps a data frame even when it has a single column.
  df <- df[is_complete, , drop = FALSE]

  # The outcome must be numeric; values that cannot be parsed become NA and
  # are dropped (the coercion warning is expected, hence suppressed).
  if (!is.numeric(df[[outcome_var]])) {
    df[[outcome_var]] <- suppressWarnings(
      as.numeric(as.character(df[[outcome_var]]))
    )
    df <- df[!is.na(df[[outcome_var]]), , drop = FALSE]
  }
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }
  n_predictors <- length(all_vars)

  # ---- Guardrails ----
  guardrail_results <- list()
  warnings_list <- character(0)
  sample_check <- check_sample_size(
    nrow(df),
    min_required = n_predictors + 10,
    action = ACTION_BLOCK
  )
  guardrail_results <- c(guardrail_results, list(sample_check))
  log_add(glue("样本量: N = {nrow(df)}, 预测变量数 = {n_predictors}, {sample_check$reason}"))
  guardrail_status <- run_guardrail_chain(guardrail_results)
  if (guardrail_status$status == "blocked") {
    return(list(
      status = "blocked",
      message = guardrail_status$reason,
      trace_log = logs
    ))
  }

  # ---- Model formula ----
  # Non-syntactic column names are backtick-quoted so the formula parses.
  formula_str <- paste(
    .linreg_quote_term(outcome_var), "~",
    paste(.linreg_quote_term(all_vars), collapse = " + ")
  )
  formula_obj <- tryCatch(
    stats::as.formula(formula_str),
    error = function(e) {
      log_add(paste("模型公式构建失败:", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(formula_obj)) {
    return(map_r_error("线性回归模型拟合失败"))
  }
  log_add(glue("模型公式: {formula_str}"))

  # ---- Model fitting ----
  log_add("拟合线性回归模型")
  # A tryCatch() warning handler runs after unwinding, where the
  # "muffleWarning" restart no longer exists; withCallingHandlers() records
  # the warning and lets lm() finish, while tryCatch() still catches errors.
  model <- tryCatch(
    withCallingHandlers(
      lm(formula_obj, data = df),
      warning = function(w) {
        warnings_list <<- c(warnings_list, conditionMessage(w))
        log_add(paste("模型警告:", conditionMessage(w)))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("模型拟合失败:", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(model)) {
    return(map_r_error("线性回归模型拟合失败"))
  }
  model_summary <- summary(model)

  # Coefficients that are exact linear combinations of others are NA in
  # coef(model) and omitted from summary(); tell the user about them.
  aliased_terms <- names(which(model_summary$aliased))
  if (length(aliased_terms) > 0) {
    aliased_msg <- paste(
      "以下系数因完全共线性无法估计:",
      paste(aliased_terms, collapse = ", ")
    )
    warnings_list <- c(warnings_list, aliased_msg)
    log_add(aliased_msg)
  }

  # ---- Coefficient table ----
  coef_summary <- model_summary$coefficients
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    t_value = coef_summary[, "t value"],
    p_value = coef_summary[, "Pr(>|t|)"],
    stringsAsFactors = FALSE
  )

  # 95% confidence intervals. confint() also returns rows for aliased
  # coefficients that summary() drops, so align them to the summary by name.
  ci <- stats::confint(model)[coef_table$variable, , drop = FALSE]
  coef_table$ci_lower <- ci[, 1]
  coef_table$ci_upper <- ci[, 2]

  coefficients_list <- lapply(seq_len(nrow(coef_table)), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      estimate = round(row$estimate, 4),
      std_error = round(row$std_error, 4),
      t_value = round(row$t_value, 3),
      ci_lower = round(row$ci_lower, 4),
      ci_upper = round(row$ci_upper, 4),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      # p can be NaN (e.g. a perfect fit); that must read as FALSE, not NA,
      # because later `if`/`&&` checks use this flag directly.
      significant = isTRUE(row$p_value < 0.05)
    )
  })

  # ---- Goodness of fit ----
  r_squared <- model_summary$r.squared
  adj_r_squared <- model_summary$adj.r.squared
  # fstatistic is NULL for an intercept-only fit.
  f_stat <- model_summary$fstatistic
  f_p_value <- if (!is.null(f_stat)) {
    pf(f_stat[1], f_stat[2], f_stat[3], lower.tail = FALSE)
  } else {
    NA
  }
  if (!is.null(f_stat)) {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = {round(f_stat[1], 2)}, P = {round(f_p_value, 4)}"))
  } else {
    log_add(glue("R² = {round(r_squared, 4)}, Adj R² = {round(adj_r_squared, 4)}, F = NA"))
  }

  # ---- Residual diagnostics ----
  residuals_vals <- residuals(model)
  normality_p <- NA
  n_resid <- length(residuals_vals)
  # shapiro.test() only accepts between 3 and 5000 observations and errors
  # on degenerate input (e.g. all residuals identical), so guard both.
  if (n_resid >= 3 && n_resid <= 5000) {
    normality_p <- tryCatch(
      stats::shapiro.test(residuals_vals)$p.value,
      error = function(e) {
        log_add(paste("残差正态性检验失败:", conditionMessage(e)))
        NA
      }
    )
    if (!is.na(normality_p) && normality_p < 0.05) {
      warnings_list <- c(
        warnings_list,
        as.character(glue("残差不满足正态性 (Shapiro-Wilk p = {round(normality_p, 4)})"))
      )
    }
  }

  # ---- Collinearity (VIF) ----
  # Best effort: skipped when `car` is not installed; failures (e.g. aliased
  # coefficients) are logged and the analysis continues without VIFs.
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        # When some term has more than one df, car::vif returns a matrix;
        # keep its GVIF column.
        if (is.matrix(vif_values)) {
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(vif_values[[v]], 2))
        })
        high_vif <- names(vif_values)[which(vif_values > 5)]
        if (length(high_vif) > 0) {
          warnings_list <- c(
            warnings_list,
            paste("VIF > 5 的变量:", paste(high_vif, collapse = ", "))
          )
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", e$message))
    })
  }

  # ---- Plots ----
  log_add("生成诊断图")
  plot_base64 <- tryCatch({
    generate_regression_plots(model, outcome_var)
  }, error = function(e) {
    log_add(paste("图表生成失败:", e$message))
    NULL
  })

  # ---- Reproducible code ----
  original_filename <- if (!is.null(input$original_filename) &&
                           nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }
  # Emit the file name as an escaped R string literal so quotes or
  # backslashes in a user-supplied name cannot break the generated script.
  file_literal <- encodeString(original_filename, quote = "\"")

  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 线性回归
# 时间: {Sys.time()}
# ================================

# 数据准备
df <- read.csv({file_literal})

# 线性回归
model <- lm({formula_str}, data = df)
summary(model)

# 置信区间
confint(model)

# 残差诊断
par(mfrow = c(2, 2))
plot(model)

# VIF(需要 car 包)
# library(car)
# vif(model)
')

  # ---- Report blocks ----
  blocks <- list()

  # Block 1: model overview
  kv_model <- list(
    "模型公式" = formula_str,
    "观测数" = as.character(nrow(df)),
    "预测变量数" = as.character(n_predictors),
    "R²" = as.character(round(r_squared, 4)),
    "调整 R²" = as.character(round(adj_r_squared, 4))
  )
  if (!is.null(f_stat)) {
    kv_model[["F 统计量"]] <- as.character(round(f_stat[1], 2))
    kv_model[["模型 P 值"]] <- format_p_value(f_p_value)
  }
  if (!is.na(normality_p)) {
    kv_model[["残差正态性 (Shapiro P)"]] <- format_p_value(normality_p)
  }
  blocks[[length(blocks) + 1]] <- make_kv_block(kv_model, title = "模型概况")

  # Block 2: regression coefficient table
  coef_headers <- c("变量", "系数 (B)", "标准误", "t 值", "95% CI", "P 值", "显著性")
  coef_rows <- lapply(coefficients_list, function(row) {
    ci_str <- sprintf("[%.4f, %.4f]", row$ci_lower, row$ci_upper)
    sig <- if (row$significant) "*" else ""
    c(
      row$variable,
      as.character(row$estimate),
      as.character(row$std_error),
      as.character(row$t_value),
      ci_str,
      row$p_value_fmt,
      sig
    )
  })
  blocks[[length(blocks) + 1]] <- make_table_block(
    coef_headers, coef_rows,
    title = "回归系数表", footnote = "* P < 0.05"
  )

  # Block 3: VIF table
  if (!is.null(vif_results) && length(vif_results) > 0) {
    vif_headers <- c("变量", "VIF")
    vif_rows <- lapply(vif_results, function(row) {
      c(row$variable, as.character(row$vif))
    })
    blocks[[length(blocks) + 1]] <- make_table_block(
      vif_headers, vif_rows,
      title = "方差膨胀因子 (VIF)"
    )
  }

  # Block 4: diagnostic plot
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(
      plot_base64,
      title = "回归诊断图", alt = "残差 vs 拟合值 + Q-Q 图"
    )
  }

  # Block 5: conclusion summary
  sig_coefs <- Filter(
    function(r) r$variable != "(Intercept)" && r$significant,
    coefficients_list
  )
  sig_vars <- vapply(sig_coefs, function(r) r$variable, character(1))

  model_sig <- if (!is.na(f_p_value) && f_p_value < 0.05) "整体具有统计学意义" else "整体不具有统计学意义"
  f_display <- if (!is.null(f_stat)) round(f_stat[1], 2) else "NA"
  p_display <- if (!is.na(f_p_value)) format_p_value(f_p_value) else "NA"

  conclusion <- glue("线性回归模型{model_sig}(F = {f_display},P {p_display})。模型解释了因变量 {round(r_squared * 100, 1)}% 的变异(R² = {round(r_squared, 4)},调整 R² = {round(adj_r_squared, 4)})。")

  if (length(sig_vars) > 0) {
    conclusion <- paste0(conclusion, glue("\n\n在 α = 0.05 水平下,以下预测变量具有统计学意义:**{paste(sig_vars, collapse = '**、**')}**。"))
  } else {
    conclusion <- paste0(conclusion, "\n\n在 α = 0.05 水平下,无预测变量达到统计学意义。")
  }

  if (length(warnings_list) > 0) {
    conclusion <- paste0(conclusion, "\n\n⚠️ 注意:", paste(warnings_list, collapse = ";"), "。")
  }

  blocks[[length(blocks) + 1]] <- make_markdown_block(conclusion, title = "结论摘要")

  # ---- Result ----
  log_add("分析完成")

  list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Multiple Linear Regression (OLS)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        r_squared = jsonlite::unbox(round(r_squared, 4)),
        adj_r_squared = jsonlite::unbox(round(adj_r_squared, 4)),
        f_statistic = if (!is.null(f_stat)) jsonlite::unbox(round(f_stat[1], 2)) else NULL,
        f_df = if (!is.null(f_stat)) as.numeric(f_stat[2:3]) else NULL,
        f_p_value = if (!is.na(f_p_value)) jsonlite::unbox(round(f_p_value, 4)) else NULL,
        f_p_value_fmt = if (!is.na(f_p_value)) format_p_value(f_p_value) else NULL
      ),
      diagnostics = list(
        residual_normality_p = if (!is.na(normality_p)) jsonlite::unbox(round(normality_p, 4)) else NULL
      ),
      vif = vif_results
    ),
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  )
}

# Normalise a variable-name parameter to a character vector with no NA,
# no empty strings and no duplicates. Accepts NULL, vectors and lists.
.linreg_clean_vars <- function(x) {
  x <- as.character(unlist(x, use.names = FALSE))
  unique(x[!is.na(x) & nzchar(x)])
}

# Backtick-quote names that are not syntactic R identifiers (spaces,
# punctuation, leading digits, reserved words) so they can be pasted into a
# model formula; syntactic names are returned unchanged.
.linreg_quote_term <- function(x) {
  needs_quote <- make.names(x) != x
  x[needs_quote] <- paste0("`", x[needs_quote], "`")
  x
}

# Build the regression diagnostic figure (residuals vs fitted next to a
# normal Q-Q plot of standardized residuals) and return it as a base64 PNG
# data URI. When gridExtra is unavailable only the residual plot is saved.
# `outcome_var` is currently unused; it is kept for interface compatibility.
generate_regression_plots <- function(model, outcome_var) {
  diag_df <- data.frame(
    fitted = fitted(model),
    residuals = residuals(model),
    std_residuals = rstandard(model)
  )

  # Residuals vs fitted values
  p1 <- ggplot(diag_df, aes(x = fitted, y = residuals)) +
    geom_point(alpha = 0.5, color = "#3b82f6") +
    geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
    # Explicit formula avoids ggplot2's "using formula = 'y ~ x'" message.
    geom_smooth(method = "loess", formula = y ~ x, se = FALSE,
                color = "orange", linewidth = 0.8) +
    theme_minimal() +
    labs(title = "Residuals vs Fitted", x = "Fitted values", y = "Residuals")

  # Normal Q-Q plot of standardized residuals
  p2 <- ggplot(diag_df, aes(sample = std_residuals)) +
    stat_qq(alpha = 0.5, color = "#3b82f6") +
    stat_qq_line(color = "red", linetype = "dashed") +
    theme_minimal() +
    labs(title = "Normal Q-Q Plot", x = "Theoretical Quantiles",
         y = "Standardized Residuals")

  tmp_file <- tempfile(fileext = ".png")
  # Remove the temporary file even if ggsave() or encoding fails.
  on.exit(unlink(tmp_file), add = TRUE)

  if (requireNamespace("gridExtra", quietly = TRUE)) {
    combined <- gridExtra::arrangeGrob(p1, p2, ncol = 2)
    ggsave(tmp_file, combined, width = 12, height = 5, dpi = 100)
  } else {
    ggsave(tmp_file, p1, width = 7, height = 5, dpi = 100)
  }

  paste0("data:image/png;base64,", base64encode(tmp_file))
}