Files
AIclinicalresearch/r-statistics-service/tools/logistic_binary.R
HaHafeng 428a22adf2 feat(ssa): Complete Phase 2A frontend integration - multi-step workflow end-to-end
Phase 2A: WorkflowPlannerService, WorkflowExecutorService, Python data quality, 6 bug fixes, DescriptiveResultView, multi-step R code/Word export, MVP UI reuse. V11 UI: Gemini-style, multi-task, single-page scroll, Word export. Architecture: Block-based rendering consensus (4 block types). New R tools: chi_square, correlation, descriptive, logistic_binary, mann_whitney, t_test_paired. Docs: dev summary, block-based plan, status updates, task list v2.0.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-20 23:09:27 +08:00

317 lines
9.5 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_LOGISTIC_BINARY
#' @name 二元 Logistic 回归
#' @version 1.0.0
#' @description 二分类结局变量的多因素分析
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
# Fit a binary logistic regression and package the result for the service.
#
# Args:
#   input: list describing the request. Fields read here:
#     - params$outcome_var  name of the two-level outcome column
#     - params$predictors   names of the predictor columns (at least one)
#     - params$confounders  optional names of adjustment columns
#     - original_filename   optional; only used in the generated script
#     The whole object is also handed to load_input_data() (defined
#     elsewhere) to obtain the data frame.
#
# Returns:
#   On success a list with status = "success", message, warnings, results
#   (odds ratios with Wald 95% CI, fit statistics, VIF, EPV), plots,
#   trace_log and reproducible_code. When the guardrail chain reports
#   "blocked", a list with status = "blocked", message and trace_log.
#   Validation and fitting failures return whatever make_error() /
#   map_r_error() produce (both defined elsewhere).
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- c()
  # Appends a timestamped entry to the trace log kept in this call's frame.
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  # unlist() so that JSON arrays that arrive as R lists are handled as well.
  outcome_var <- as.character(unlist(p$outcome_var))
  predictors <- p$predictors   # predictor variables
  confounders <- p$confounders # confounders (optional)

  # ===== Parameter validation =====
  # Without this guard a missing outcome_var makes `if` fail with
  # "argument is of length zero" instead of returning a structured error.
  if (length(outcome_var) != 1 || is.na(outcome_var) || !nzchar(outcome_var)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "缺少结局变量参数"))
  }
  if (!(outcome_var %in% names(df))) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = outcome_var))
  }

  # Drop empty/NA entries; de-duplicate so a variable listed as both predictor
  # and confounder is not counted twice in n_predictors / EPV.
  all_vars <- as.character(unlist(c(predictors, confounders)))
  all_vars <- unique(all_vars[!is.na(all_vars) & nzchar(all_vars)])

  missing_vars <- setdiff(all_vars, names(df))
  if (length(missing_vars) > 0) {
    # Report the first missing column, in request order.
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = missing_vars[1]))
  }
  if (length(predictors) == 0 || length(all_vars) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ===== Data cleaning =====
  original_rows <- nrow(df)
  # Keep only rows that are complete on every variable used by the model.
  vars_to_check <- unique(c(outcome_var, all_vars))
  keep_rows <- complete.cases(df[, vars_to_check, drop = FALSE])
  df <- df[keep_rows, , drop = FALSE]
  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  # ===== Outcome checks =====
  outcome_values <- unique(df[[outcome_var]])
  if (length(outcome_values) != 2) {
    return(make_error(ERROR_CODES$E003_INSUFFICIENT_GROUPS,
                      col = outcome_var, expected = 2,
                      actual = length(outcome_values)))
  }
  # glm() needs a factor (or 0/1) outcome. For an existing factor, drop levels
  # that no longer occur so they do not show up as zero-count cells below.
  if (is.factor(df[[outcome_var]])) {
    df[[outcome_var]] <- droplevels(df[[outcome_var]])
  } else {
    df[[outcome_var]] <- as.factor(df[[outcome_var]])
  }

  # The rarer outcome class is treated as the "event" for the EPV rule.
  event_counts <- table(df[[outcome_var]])
  n_events <- min(event_counts)
  n_predictors <- length(all_vars)
  log_add(glue("结局变量分布: {paste(names(event_counts), '=', event_counts, collapse=', ')}"))
  log_add(glue("事件数: {n_events}, 预测变量数: {n_predictors}"))

  # ===== Guardrails =====
  guardrail_results <- list()
  warnings_list <- c()

  # Events-per-variable rule of thumb (EPV >= 10); warn only, do not block.
  epv <- n_events / n_predictors
  if (epv < 10) {
    warnings_list <- c(warnings_list, glue("EPV = {round(epv, 1)} < 10模型可能不稳定"))
    log_add(glue("警告: EPV = {round(epv, 1)} < 10"))
  }

  # Minimum sample size; check_sample_size / run_guardrail_chain are defined
  # elsewhere and decide whether the analysis is blocked.
  sample_check <- check_sample_size(nrow(df), min_required = 20, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))
  guardrail_status <- run_guardrail_chain(guardrail_results)
  if (guardrail_status$status == "blocked") {
    return(list(
      status = "blocked",
      message = guardrail_status$reason,
      trace_log = logs
    ))
  }

  # ===== Model formula =====
  # Back-quote names that are not syntactic R identifiers (spaces, operators,
  # reserved words) so they are parsed as a single variable. Syntactic names
  # are left as-is, so the reported formula is unchanged for them.
  quote_term <- function(v) {
    if (identical(make.names(v), v)) v else paste0("`", v, "`")
  }
  rhs_terms <- vapply(all_vars, quote_term, character(1), USE.NAMES = FALSE)
  formula_str <- paste(quote_term(outcome_var), "~", paste(rhs_terms, collapse = " + "))
  formula_obj <- as.formula(formula_str)
  log_add(glue("模型公式: {formula_str}"))

  # ===== Core computation =====
  log_add("拟合 Logistic 回归模型")
  # Warnings are intercepted with a *calling* handler so that the
  # "muffleWarning" restart is still on the stack and glm() can finish.
  # An exiting tryCatch(warning = ) handler would have unwound the stack
  # already, making invokeRestart() itself raise an error.
  model <- tryCatch(
    withCallingHandlers(
      glm(formula_obj, data = df, family = binomial(link = "logit")),
      warning = function(w) {
        warnings_list <<- c(warnings_list, conditionMessage(w))
        log_add(paste("模型警告:", conditionMessage(w)))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("模型拟合失败:", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(model)) {
    return(map_r_error("模型拟合失败"))
  }

  # Convergence flag reported by glm().
  if (!model$converged) {
    warnings_list <- c(warnings_list, "模型未完全收敛")
    log_add("警告: 模型未完全收敛")
  }

  # ===== Extract coefficients =====
  coef_summary <- summary(model)$coefficients
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    z_value = coef_summary[, "z value"],
    p_value = coef_summary[, "Pr(>|z|)"],
    stringsAsFactors = FALSE
  )
  # Odds ratio and Wald 95% CI: exponentiate estimate +/- 1.96 * SE.
  coef_table$OR <- exp(coef_table$estimate)
  coef_table$ci_lower <- exp(coef_table$estimate - 1.96 * coef_table$std_error)
  coef_table$ci_upper <- exp(coef_table$estimate + 1.96 * coef_table$std_error)

  # Compact per-term list for the response (raw coefficients are omitted).
  coefficients_list <- lapply(seq_len(nrow(coef_table)), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      OR = round(row$OR, 3),
      ci_lower = round(row$ci_lower, 3),
      ci_upper = round(row$ci_upper, 3),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      significant = row$p_value < 0.05
    )
  })

  # ===== Goodness of fit =====
  null_deviance <- model$null.deviance
  residual_deviance <- model$deviance
  aic <- AIC(model)
  # Nagelkerke pseudo R-squared: the Cox-Snell value divided by its maximum,
  # both computed from the deviances.
  n <- nrow(df)
  r2_nagelkerke <- (1 - exp((residual_deviance - null_deviance) / n)) /
    (1 - exp(-null_deviance / n))
  log_add(glue("AIC = {round(aic, 2)}, Nagelkerke R² = {round(r2_nagelkerke, 3)}"))

  # ===== Collinearity (VIF) =====
  # Only meaningful with two or more terms; skipped silently when the optional
  # 'car' package is not installed. Failures are logged, not fatal.
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        # A matrix result is reduced to its "GVIF" column.
        if (is.matrix(vif_values)) {
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(unname(vif_values[[v]]), 2))
        })
        high_vif <- names(vif_values)[vif_values > 5]
        if (length(high_vif) > 0) {
          warnings_list <- c(warnings_list,
                             paste("VIF > 5 的变量:", paste(high_vif, collapse = ", ")))
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", conditionMessage(e)))
    })
  }

  # ===== Forest plot =====
  # Plot failures are logged and simply yield an empty plots list.
  log_add("生成森林图")
  plot_base64 <- tryCatch({
    generate_forest_plot(coef_table)
  }, error = function(e) {
    log_add(paste("图表生成失败:", conditionMessage(e)))
    NULL
  })

  # ===== Reproducible code =====
  original_filename <- if (!is.null(input$original_filename) &&
                           nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }
  # The template is kept flush-left on purpose so the emitted script text does
  # not depend on glue's indentation trimming.
  reproducible_code <- glue('
# SSA-Pro 自动生成代码
# 工具: 二元 Logistic 回归
# 时间: {Sys.time()}
# ================================
# 数据准备
df <- read.csv("{original_filename}")
# 模型拟合
model <- glm({formula_str}, data = df, family = binomial(link = "logit"))
summary(model)
# OR 和 95% CI
coef_summary <- summary(model)$coefficients
OR <- exp(coef_summary[, "Estimate"])
CI_lower <- exp(coef_summary[, "Estimate"] - 1.96 * coef_summary[, "Std. Error"])
CI_upper <- exp(coef_summary[, "Estimate"] + 1.96 * coef_summary[, "Std. Error"])
results <- data.frame(OR = OR, CI_lower = CI_lower, CI_upper = CI_upper,
p_value = coef_summary[, "Pr(>|z|)"])
print(round(results, 3))
# 模型拟合度
cat("AIC:", AIC(model), "\\n")
# VIF需要 car 包)
# library(car)
# vif(model)
')

  # ===== Result =====
  log_add("分析完成")
  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Binary Logistic Regression (glm, binomial)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        aic = jsonlite::unbox(round(aic, 2)),
        null_deviance = jsonlite::unbox(round(null_deviance, 2)),
        residual_deviance = jsonlite::unbox(round(residual_deviance, 2)),
        r2_nagelkerke = jsonlite::unbox(round(r2_nagelkerke, 4))
      ),
      vif = vif_results,
      epv = jsonlite::unbox(round(epv, 1))
    ),
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
# Helper: render a forest plot of odds ratios and return it as a data URI.
#
# Args:
#   coef_table: data frame with columns `variable`, `OR`, `ci_lower` and
#     `ci_upper`; the "(Intercept)" row, if present, is not plotted.
#
# Returns:
#   A "data:image/png;base64,..." string, or NULL when no row other than the
#   intercept is available.
generate_forest_plot <- function(coef_table) {
  # Drop the intercept term.
  plot_data <- coef_table[coef_table$variable != "(Intercept)", ]
  if (nrow(plot_data) == 0) {
    return(NULL)
  }

  # Reverse the levels so the first term is drawn at the top of the y axis.
  plot_data$variable <- factor(plot_data$variable, levels = rev(plot_data$variable))

  p <- ggplot(plot_data, aes(x = OR, y = variable)) +
    geom_vline(xintercept = 1, linetype = "dashed", color = "gray50") +
    geom_point(size = 3, color = "#3b82f6") +
    geom_errorbarh(aes(xmin = ci_lower, xmax = ci_upper),
                   height = 0.2, color = "#3b82f6") +
    scale_x_log10() +
    theme_minimal() +
    labs(
      title = "Forest Plot: Odds Ratios with 95% CI",
      x = "Odds Ratio (log scale)",
      y = "Variable"
    ) +
    theme(
      panel.grid.minor = element_blank(),
      axis.text.y = element_text(size = 10)
    )

  tmp_file <- tempfile(fileext = ".png")
  # Register cleanup first so the temp file is removed even when rendering or
  # encoding fails.
  on.exit(unlink(tmp_file), add = TRUE)

  # Height grows with the number of terms, with a 4-inch minimum.
  ggsave(tmp_file, p, width = 8,
         height = max(4, nrow(plot_data) * 0.5 + 2), dpi = 100)
  base64_str <- base64encode(tmp_file)
  paste0("data:image/png;base64,", base64_str)
}