feat(ssa): Complete Phase 2A frontend integration - multi-step workflow end-to-end
Phase 2A: WorkflowPlannerService, WorkflowExecutorService, Python data quality, 6 bug fixes, DescriptiveResultView, multi-step R code/Word export, MVP UI reuse. V11 UI: Gemini-style, multi-task, single-page scroll, Word export. Architecture: Block-based rendering consensus (4 block types). New R tools: chi_square, correlation, descriptive, logistic_binary, mann_whitney, t_test_paired. Docs: dev summary, block-based plan, status updates, task list v2.0. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
316
r-statistics-service/tools/logistic_binary.R
Normal file
316
r-statistics-service/tools/logistic_binary.R
Normal file
@@ -0,0 +1,316 @@
|
||||
#' @tool_code ST_LOGISTIC_BINARY
|
||||
#' @name 二元 Logistic 回归
|
||||
#' @version 1.0.0
|
||||
#' @description 二分类结局变量的多因素分析
|
||||
#' @author SSA-Pro Team
|
||||
|
||||
library(glue)
|
||||
library(ggplot2)
|
||||
library(base64enc)
|
||||
|
||||
run_analysis <- function(input) {
|
||||
# ===== 初始化 =====
|
||||
logs <- c()
|
||||
log_add <- function(msg) { logs <<- c(logs, paste0("[", Sys.time(), "] ", msg)) }
|
||||
|
||||
on.exit({}, add = TRUE)
|
||||
|
||||
# ===== 数据加载 =====
|
||||
log_add("开始加载输入数据")
|
||||
df <- tryCatch(
|
||||
load_input_data(input),
|
||||
error = function(e) {
|
||||
log_add(paste("数据加载失败:", e$message))
|
||||
return(NULL)
|
||||
}
|
||||
)
|
||||
|
||||
if (is.null(df)) {
|
||||
return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
|
||||
}
|
||||
log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))
|
||||
|
||||
p <- input$params
|
||||
outcome_var <- p$outcome_var
|
||||
predictors <- p$predictors # 预测变量列表
|
||||
confounders <- p$confounders # 混杂因素(可选)
|
||||
|
||||
# ===== 参数校验 =====
|
||||
if (!(outcome_var %in% names(df))) {
|
||||
return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = outcome_var))
|
||||
}
|
||||
|
||||
all_vars <- c(predictors, confounders)
|
||||
all_vars <- all_vars[!is.null(all_vars) & all_vars != ""]
|
||||
|
||||
for (v in all_vars) {
|
||||
if (!(v %in% names(df))) {
|
||||
return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = v))
|
||||
}
|
||||
}
|
||||
|
||||
if (length(predictors) == 0) {
|
||||
return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
|
||||
}
|
||||
|
||||
# ===== 数据清洗 =====
|
||||
original_rows <- nrow(df)
|
||||
|
||||
# 移除所有相关变量的缺失值
|
||||
vars_to_check <- c(outcome_var, all_vars)
|
||||
for (v in vars_to_check) {
|
||||
df <- df[!is.na(df[[v]]), ]
|
||||
}
|
||||
|
||||
removed_rows <- original_rows - nrow(df)
|
||||
if (removed_rows > 0) {
|
||||
log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
|
||||
}
|
||||
|
||||
# ===== 结局变量检查 =====
|
||||
outcome_values <- unique(df[[outcome_var]])
|
||||
if (length(outcome_values) != 2) {
|
||||
return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
|
||||
col = outcome_var, expected = 2, actual = length(outcome_values)))
|
||||
}
|
||||
|
||||
# 确保结局变量是 0/1 或因子
|
||||
if (!is.factor(df[[outcome_var]])) {
|
||||
df[[outcome_var]] <- as.factor(df[[outcome_var]])
|
||||
}
|
||||
|
||||
# 事件数统计
|
||||
event_counts <- table(df[[outcome_var]])
|
||||
n_events <- min(event_counts)
|
||||
n_predictors <- length(all_vars)
|
||||
|
||||
log_add(glue("结局变量分布: {paste(names(event_counts), '=', event_counts, collapse=', ')}"))
|
||||
log_add(glue("事件数: {n_events}, 预测变量数: {n_predictors}"))
|
||||
|
||||
# ===== 护栏检查 =====
|
||||
guardrail_results <- list()
|
||||
warnings_list <- c()
|
||||
|
||||
# EPV 规则检查(Events Per Variable >= 10)
|
||||
epv <- n_events / n_predictors
|
||||
if (epv < 10) {
|
||||
warnings_list <- c(warnings_list, glue("EPV = {round(epv, 1)} < 10,模型可能不稳定"))
|
||||
log_add(glue("警告: EPV = {round(epv, 1)} < 10"))
|
||||
}
|
||||
|
||||
# 样本量检查
|
||||
sample_check <- check_sample_size(nrow(df), min_required = 20, action = ACTION_BLOCK)
|
||||
guardrail_results <- c(guardrail_results, list(sample_check))
|
||||
|
||||
guardrail_status <- run_guardrail_chain(guardrail_results)
|
||||
|
||||
if (guardrail_status$status == "blocked") {
|
||||
return(list(
|
||||
status = "blocked",
|
||||
message = guardrail_status$reason,
|
||||
trace_log = logs
|
||||
))
|
||||
}
|
||||
|
||||
# ===== 构建模型公式 =====
|
||||
formula_str <- paste(outcome_var, "~", paste(all_vars, collapse = " + "))
|
||||
formula_obj <- as.formula(formula_str)
|
||||
log_add(glue("模型公式: {formula_str}"))
|
||||
|
||||
# ===== 核心计算 =====
|
||||
log_add("拟合 Logistic 回归模型")
|
||||
|
||||
model <- tryCatch({
|
||||
glm(formula_obj, data = df, family = binomial(link = "logit"))
|
||||
}, error = function(e) {
|
||||
log_add(paste("模型拟合失败:", e$message))
|
||||
return(NULL)
|
||||
}, warning = function(w) {
|
||||
warnings_list <<- c(warnings_list, w$message)
|
||||
log_add(paste("模型警告:", w$message))
|
||||
invokeRestart("muffleWarning")
|
||||
})
|
||||
|
||||
if (is.null(model)) {
|
||||
return(map_r_error("模型拟合失败"))
|
||||
}
|
||||
|
||||
# 检查模型收敛
|
||||
if (!model$converged) {
|
||||
warnings_list <- c(warnings_list, "模型未完全收敛")
|
||||
log_add("警告: 模型未完全收敛")
|
||||
}
|
||||
|
||||
# ===== 提取模型结果 =====
|
||||
coef_summary <- summary(model)$coefficients
|
||||
|
||||
# 计算 OR 和 95% CI
|
||||
coef_table <- data.frame(
|
||||
variable = rownames(coef_summary),
|
||||
estimate = coef_summary[, "Estimate"],
|
||||
std_error = coef_summary[, "Std. Error"],
|
||||
z_value = coef_summary[, "z value"],
|
||||
p_value = coef_summary[, "Pr(>|z|)"],
|
||||
stringsAsFactors = FALSE
|
||||
)
|
||||
|
||||
coef_table$OR <- exp(coef_table$estimate)
|
||||
coef_table$ci_lower <- exp(coef_table$estimate - 1.96 * coef_table$std_error)
|
||||
coef_table$ci_upper <- exp(coef_table$estimate + 1.96 * coef_table$std_error)
|
||||
|
||||
# 转换为列表格式(精简,不含原始系数)
|
||||
coefficients_list <- lapply(1:nrow(coef_table), function(i) {
|
||||
row <- coef_table[i, ]
|
||||
list(
|
||||
variable = row$variable,
|
||||
OR = round(row$OR, 3),
|
||||
ci_lower = round(row$ci_lower, 3),
|
||||
ci_upper = round(row$ci_upper, 3),
|
||||
p_value = round(row$p_value, 4),
|
||||
p_value_fmt = format_p_value(row$p_value),
|
||||
significant = row$p_value < 0.05
|
||||
)
|
||||
})
|
||||
|
||||
# ===== 模型拟合度 =====
|
||||
null_deviance <- model$null.deviance
|
||||
residual_deviance <- model$deviance
|
||||
aic <- AIC(model)
|
||||
|
||||
# Nagelkerke R²(伪 R²)
|
||||
n <- nrow(df)
|
||||
r2_nagelkerke <- (1 - exp((residual_deviance - null_deviance) / n)) / (1 - exp(-null_deviance / n))
|
||||
|
||||
log_add(glue("AIC = {round(aic, 2)}, Nagelkerke R² = {round(r2_nagelkerke, 3)}"))
|
||||
|
||||
# ===== 共线性检测(VIF) =====
|
||||
vif_results <- NULL
|
||||
if (length(all_vars) > 1) {
|
||||
tryCatch({
|
||||
if (requireNamespace("car", quietly = TRUE)) {
|
||||
vif_values <- car::vif(model)
|
||||
if (is.matrix(vif_values)) {
|
||||
vif_values <- vif_values[, "GVIF"]
|
||||
}
|
||||
vif_results <- lapply(names(vif_values), function(v) {
|
||||
list(variable = v, vif = round(vif_values[v], 2))
|
||||
})
|
||||
|
||||
high_vif <- names(vif_values)[vif_values > 5]
|
||||
if (length(high_vif) > 0) {
|
||||
warnings_list <- c(warnings_list, paste("VIF > 5 的变量:", paste(high_vif, collapse = ", ")))
|
||||
}
|
||||
}
|
||||
}, error = function(e) {
|
||||
log_add(paste("VIF 计算失败:", e$message))
|
||||
})
|
||||
}
|
||||
|
||||
# ===== 生成图表(森林图) =====
|
||||
log_add("生成森林图")
|
||||
plot_base64 <- tryCatch({
|
||||
generate_forest_plot(coef_table)
|
||||
}, error = function(e) {
|
||||
log_add(paste("图表生成失败:", e$message))
|
||||
NULL
|
||||
})
|
||||
|
||||
# ===== 生成可复现代码 =====
|
||||
original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
|
||||
input$original_filename
|
||||
} else {
|
||||
"data.csv"
|
||||
}
|
||||
|
||||
reproducible_code <- glue('
|
||||
# SSA-Pro 自动生成代码
|
||||
# 工具: 二元 Logistic 回归
|
||||
# 时间: {Sys.time()}
|
||||
# ================================
|
||||
|
||||
# 数据准备
|
||||
df <- read.csv("{original_filename}")
|
||||
|
||||
# 模型拟合
|
||||
model <- glm({formula_str}, data = df, family = binomial(link = "logit"))
|
||||
summary(model)
|
||||
|
||||
# OR 和 95% CI
|
||||
coef_summary <- summary(model)$coefficients
|
||||
OR <- exp(coef_summary[, "Estimate"])
|
||||
CI_lower <- exp(coef_summary[, "Estimate"] - 1.96 * coef_summary[, "Std. Error"])
|
||||
CI_upper <- exp(coef_summary[, "Estimate"] + 1.96 * coef_summary[, "Std. Error"])
|
||||
results <- data.frame(OR = OR, CI_lower = CI_lower, CI_upper = CI_upper,
|
||||
p_value = coef_summary[, "Pr(>|z|)"])
|
||||
print(round(results, 3))
|
||||
|
||||
# 模型拟合度
|
||||
cat("AIC:", AIC(model), "\\n")
|
||||
|
||||
# VIF(需要 car 包)
|
||||
# library(car)
|
||||
# vif(model)
|
||||
')
|
||||
|
||||
# ===== 返回结果 =====
|
||||
log_add("分析完成")
|
||||
|
||||
return(list(
|
||||
status = "success",
|
||||
message = "分析完成",
|
||||
warnings = if (length(warnings_list) > 0) warnings_list else NULL,
|
||||
results = list(
|
||||
method = "Binary Logistic Regression (glm, binomial)",
|
||||
formula = formula_str,
|
||||
n_observations = nrow(df),
|
||||
n_predictors = n_predictors,
|
||||
coefficients = coefficients_list,
|
||||
model_fit = list(
|
||||
aic = jsonlite::unbox(round(aic, 2)),
|
||||
null_deviance = jsonlite::unbox(round(null_deviance, 2)),
|
||||
residual_deviance = jsonlite::unbox(round(residual_deviance, 2)),
|
||||
r2_nagelkerke = jsonlite::unbox(round(r2_nagelkerke, 4))
|
||||
),
|
||||
vif = vif_results,
|
||||
epv = jsonlite::unbox(round(epv, 1))
|
||||
),
|
||||
plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
|
||||
trace_log = logs,
|
||||
reproducible_code = as.character(reproducible_code)
|
||||
))
|
||||
}
|
||||
|
||||
# 辅助函数:生成森林图
|
||||
generate_forest_plot <- function(coef_table) {
|
||||
# 移除截距项
|
||||
plot_data <- coef_table[coef_table$variable != "(Intercept)", ]
|
||||
|
||||
if (nrow(plot_data) == 0) {
|
||||
return(NULL)
|
||||
}
|
||||
|
||||
plot_data$variable <- factor(plot_data$variable, levels = rev(plot_data$variable))
|
||||
|
||||
p <- ggplot(plot_data, aes(x = OR, y = variable)) +
|
||||
geom_vline(xintercept = 1, linetype = "dashed", color = "gray50") +
|
||||
geom_point(size = 3, color = "#3b82f6") +
|
||||
geom_errorbarh(aes(xmin = ci_lower, xmax = ci_upper), height = 0.2, color = "#3b82f6") +
|
||||
scale_x_log10() +
|
||||
theme_minimal() +
|
||||
labs(
|
||||
title = "Forest Plot: Odds Ratios with 95% CI",
|
||||
x = "Odds Ratio (log scale)",
|
||||
y = "Variable"
|
||||
) +
|
||||
theme(
|
||||
panel.grid.minor = element_blank(),
|
||||
axis.text.y = element_text(size = 10)
|
||||
)
|
||||
|
||||
tmp_file <- tempfile(fileext = ".png")
|
||||
ggsave(tmp_file, p, width = 8, height = max(4, nrow(plot_data) * 0.5 + 2), dpi = 100)
|
||||
base64_str <- base64encode(tmp_file)
|
||||
unlink(tmp_file)
|
||||
|
||||
return(paste0("data:image/png;base64,", base64_str))
|
||||
}
|
||||
Reference in New Issue
Block a user