Files
AIclinicalresearch/r-statistics-service/tools/logistic_binary.R
HaHafeng 371e1c069c feat(ssa): Complete QPER architecture - Query, Planner, Execute, Reflection layers
Implement the full QPER intelligent analysis pipeline:

- Phase E+: Block-based standardization for all 7 R tools, DynamicReport renderer, Word export enhancement

- Phase Q: LLM intent parsing with dynamic Zod validation against real column names, ClarificationCard component, DataProfile is_id_like tagging

- Phase P: ConfigLoader with Zod schema validation and hot-reload API, DecisionTableService (4-dimension matching), FlowTemplateService with EPV protection, PlannedTrace audit output

- Phase R: ReflectionService with statistical slot injection, sensitivity analysis conflict rules, ConclusionReport with section reveal animation, conclusion caching API, graceful R error classification

End-to-end test: 40/40 passed across two complete analysis scenarios.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-21 18:15:53 +08:00

370 lines
12 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#' @tool_code ST_LOGISTIC_BINARY
#' @name 二元 Logistic 回归
#' @version 1.0.0
#' @description 二分类结局变量的多因素分析
#' @author SSA-Pro Team
library(glue)
library(ggplot2)
library(base64enc)
#' Fit a binary logistic regression and assemble the tool response.
#'
#' Loads the data via `load_input_data(input)`, validates the requested
#' columns, drops rows with missing values in any model variable, fits
#' `glm(..., family = binomial(link = "logit"))`, and returns odds ratios with
#' Wald 95% confidence intervals, fit statistics, optional VIFs, a forest
#' plot, report blocks and a reproducible code snippet.
#'
#' @param input A list. Fields read here: `params$outcome_var` (name of the
#'   outcome column), `params$predictors` (predictor column names),
#'   `params$confounders` (optional confounder column names) and
#'   `original_filename` (optional; used only in the generated code snippet).
#'   The whole object is also forwarded to `load_input_data()`.
#' @return On success, a list with `status = "success"`, `message`,
#'   `warnings`, `results`, `report_blocks`, `plots`, `trace_log` and
#'   `reproducible_code`. When the guardrail chain blocks the run, a list with
#'   `status = "blocked"`. On validation or fitting failure, whatever
#'   `make_error()` / `map_r_error()` return.
run_analysis <- function(input) {
  # ===== Initialisation =====
  logs <- c()
  log_add <- function(msg) {
    logs <<- c(logs, paste0("[", Sys.time(), "] ", msg))
  }

  # Normalise a user-supplied name collection (NULL, character vector or list)
  # into a character vector without NA / empty entries.
  clean_names <- function(x) {
    x <- as.character(unlist(x))
    x[!is.na(x) & nzchar(x)]
  }

  # ===== Data loading =====
  log_add("开始加载输入数据")
  df <- tryCatch(
    load_input_data(input),
    error = function(e) {
      log_add(paste("数据加载失败:", conditionMessage(e)))
      NULL
    }
  )
  if (is.null(df)) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "数据加载失败"))
  }
  log_add(glue("数据加载成功: {nrow(df)} 行, {ncol(df)} 列"))

  p <- input$params
  outcome_var <- clean_names(p$outcome_var)
  predictors <- clean_names(p$predictors)   # predictor variables
  confounders <- clean_names(p$confounders) # confounders (optional)

  # ===== Parameter validation =====
  # A missing / empty outcome name is reported as "column not found" instead
  # of failing inside `if ()` with a zero-length condition.
  if (length(outcome_var) != 1 || !(outcome_var %in% names(df))) {
    return(make_error(
      ERROR_CODES$E001_COLUMN_NOT_FOUND,
      col = paste(outcome_var, collapse = ", ")
    ))
  }

  # De-duplicate so a variable listed as both predictor and confounder is
  # counted once (the formula would collapse it to a single term anyway).
  all_vars <- unique(c(predictors, confounders))
  missing_cols <- setdiff(all_vars, names(df))
  if (length(missing_cols) > 0) {
    return(make_error(ERROR_CODES$E001_COLUMN_NOT_FOUND, col = missing_cols[1]))
  }

  if (length(predictors) == 0) {
    return(make_error(ERROR_CODES$E100_INTERNAL_ERROR, details = "至少需要一个预测变量"))
  }

  # ===== Data cleaning =====
  original_rows <- nrow(df)

  # Drop every row that has a missing value in any model variable.
  vars_to_check <- c(outcome_var, all_vars)
  keep_row <- rep(TRUE, nrow(df))
  for (v in vars_to_check) {
    keep_row <- keep_row & !is.na(df[[v]])
  }
  df <- df[keep_row, ]

  removed_rows <- original_rows - nrow(df)
  if (removed_rows > 0) {
    log_add(glue("数据清洗: 移除 {removed_rows} 行缺失值 (剩余 {nrow(df)} 行)"))
  }

  # ===== Outcome variable check =====
  outcome_values <- unique(df[[outcome_var]])
  if (length(outcome_values) != 2) {
    return(make_error(
      ERROR_CODES$E003_INSUFFICIENT_GROUPS,
      col = outcome_var, expected = 2, actual = length(outcome_values)
    ))
  }

  # Make the outcome a factor and drop unused levels: an empty level left over
  # from a pre-existing factor would otherwise show up as a zero count below
  # and distort the event count / EPV.
  df[[outcome_var]] <- factor(df[[outcome_var]])

  # Event count = size of the smaller outcome class.
  event_counts <- table(df[[outcome_var]])
  n_events <- min(event_counts)
  n_predictors <- length(all_vars)

  log_add(glue("结局变量分布: {paste(names(event_counts), '=', event_counts, collapse=', ')}"))
  log_add(glue("事件数: {n_events}, 预测变量数: {n_predictors}"))

  # ===== Guardrails =====
  guardrail_results <- list()
  warnings_list <- c()

  # EPV rule: events per variable should be >= 10 (warning only).
  epv <- n_events / n_predictors
  if (epv < 10) {
    warnings_list <- c(warnings_list, glue("EPV = {round(epv, 1)} < 10模型可能不稳定"))
    log_add(glue("警告: EPV = {round(epv, 1)} < 10"))
  }

  # Sample-size check (blocking).
  sample_check <- check_sample_size(nrow(df), min_required = 20, action = ACTION_BLOCK)
  guardrail_results <- c(guardrail_results, list(sample_check))

  guardrail_status <- run_guardrail_chain(guardrail_results)
  if (guardrail_status$status == "blocked") {
    return(list(
      status = "blocked",
      message = guardrail_status$reason,
      trace_log = logs
    ))
  }

  # ===== Model formula =====
  formula_str <- paste(outcome_var, "~", paste(all_vars, collapse = " + "))
  formula_obj <- as.formula(formula_str)
  log_add(glue("模型公式: {formula_str}"))

  # ===== Core computation =====
  log_add("拟合 Logistic 回归模型")

  # Warnings must be intercepted with a *calling* handler: the
  # "muffleWarning" restart only exists while warning() is still on the stack.
  # Inside a tryCatch() warning handler the stack has already been unwound, so
  # invokeRestart() would itself raise an error and abort the analysis.
  model <- tryCatch(
    withCallingHandlers(
      glm(formula_obj, data = df, family = binomial(link = "logit")),
      warning = function(w) {
        warnings_list <<- c(warnings_list, conditionMessage(w))
        log_add(paste("模型警告:", conditionMessage(w)))
        invokeRestart("muffleWarning")
      }
    ),
    error = function(e) {
      log_add(paste("模型拟合失败:", conditionMessage(e)))
      NULL
    }
  )

  if (is.null(model)) {
    return(map_r_error("模型拟合失败"))
  }

  # Convergence check.
  if (!model$converged) {
    warnings_list <- c(warnings_list, "模型未完全收敛")
    log_add("警告: 模型未完全收敛")
  }

  # ===== Extract model results =====
  coef_summary <- summary(model)$coefficients

  # Odds ratios with Wald 95% confidence intervals.
  coef_table <- data.frame(
    variable = rownames(coef_summary),
    estimate = coef_summary[, "Estimate"],
    std_error = coef_summary[, "Std. Error"],
    z_value = coef_summary[, "z value"],
    p_value = coef_summary[, "Pr(>|z|)"],
    stringsAsFactors = FALSE
  )
  coef_table$OR <- exp(coef_table$estimate)
  coef_table$ci_lower <- exp(coef_table$estimate - 1.96 * coef_table$std_error)
  coef_table$ci_upper <- exp(coef_table$estimate + 1.96 * coef_table$std_error)

  # Compact list form (raw coefficients are intentionally left out).
  coefficients_list <- lapply(seq_len(nrow(coef_table)), function(i) {
    row <- coef_table[i, ]
    list(
      variable = row$variable,
      OR = round(row$OR, 3),
      ci_lower = round(row$ci_lower, 3),
      ci_upper = round(row$ci_upper, 3),
      p_value = round(row$p_value, 4),
      p_value_fmt = format_p_value(row$p_value),
      # isTRUE() keeps an NA p-value from breaking the `if` tests further down.
      significant = isTRUE(row$p_value < 0.05)
    )
  })

  # ===== Goodness of fit =====
  null_deviance <- model$null.deviance
  residual_deviance <- model$deviance
  aic <- AIC(model)

  # Nagelkerke pseudo R²: Cox-Snell R² rescaled by its maximum attainable value.
  n <- nrow(df)
  r2_nagelkerke <- (1 - exp((residual_deviance - null_deviance) / n)) /
    (1 - exp(-null_deviance / n))

  log_add(glue("AIC = {round(aic, 2)}, Nagelkerke R² = {round(r2_nagelkerke, 3)}"))

  # ===== Collinearity (VIF) =====
  # Best effort: needs the optional `car` package and more than one term.
  vif_results <- NULL
  if (length(all_vars) > 1) {
    tryCatch({
      if (requireNamespace("car", quietly = TRUE)) {
        vif_values <- car::vif(model)
        if (is.matrix(vif_values)) {
          # car::vif() returned a matrix; use its GVIF column.
          vif_values <- vif_values[, "GVIF"]
        }
        vif_results <- lapply(names(vif_values), function(v) {
          list(variable = v, vif = round(vif_values[v], 2))
        })

        high_vif <- names(vif_values)[which(vif_values > 5)]
        if (length(high_vif) > 0) {
          warnings_list <- c(
            warnings_list,
            paste("VIF > 5 的变量:", paste(high_vif, collapse = ", "))
          )
        }
      }
    }, error = function(e) {
      log_add(paste("VIF 计算失败:", conditionMessage(e)))
    })
  }

  # ===== Forest plot =====
  log_add("生成森林图")
  plot_base64 <- tryCatch(
    generate_forest_plot(coef_table),
    error = function(e) {
      log_add(paste("图表生成失败:", conditionMessage(e)))
      NULL
    }
  )

  # ===== Reproducible code =====
  original_filename <- if (!is.null(input$original_filename) && nchar(input$original_filename) > 0) {
    input$original_filename
  } else {
    "data.csv"
  }

  # glue() strips the common leading indentation of the template, so the
  # emitted snippet starts at column 0.
  reproducible_code <- glue('
    # SSA-Pro 自动生成代码
    # 工具: 二元 Logistic 回归
    # 时间: {Sys.time()}
    # ================================
    # 数据准备
    df <- read.csv("{original_filename}")
    # 模型拟合
    model <- glm({formula_str}, data = df, family = binomial(link = "logit"))
    summary(model)
    # OR 和 95% CI
    coef_summary <- summary(model)$coefficients
    OR <- exp(coef_summary[, "Estimate"])
    CI_lower <- exp(coef_summary[, "Estimate"] - 1.96 * coef_summary[, "Std. Error"])
    CI_upper <- exp(coef_summary[, "Estimate"] + 1.96 * coef_summary[, "Std. Error"])
    results <- data.frame(OR = OR, CI_lower = CI_lower, CI_upper = CI_upper,
    p_value = coef_summary[, "Pr(>|z|)"])
    print(round(results, 3))
    # 模型拟合度
    cat("AIC:", AIC(model), "\\n")
    # VIF需要 car 包)
    # library(car)
    # vif(model)
    ')

  log_add("分析完成")

  # ===== Report blocks =====
  blocks <- list()

  # Block 1: model overview.
  blocks[[length(blocks) + 1]] <- make_kv_block(list(
    "模型公式" = formula_str,
    "观测数" = as.character(nrow(df)),
    "预测变量数" = as.character(n_predictors),
    "AIC" = as.character(round(aic, 2)),
    "Nagelkerke R²" = as.character(round(r2_nagelkerke, 4)),
    "EPV" = as.character(round(epv, 1))
  ), title = "模型概况")

  # Block 2: coefficient table.
  coef_headers <- c("变量", "OR", "95% CI", "P 值", "显著性")
  coef_rows <- lapply(coefficients_list, function(row) {
    ci_str <- sprintf("[%.3f, %.3f]", row$ci_lower, row$ci_upper)
    sig <- if (row$significant) "*" else ""
    c(row$variable, as.character(row$OR), ci_str, row$p_value_fmt, sig)
  })
  blocks[[length(blocks) + 1]] <- make_table_block(
    coef_headers, coef_rows,
    title = "回归系数表", footnote = "* P < 0.05"
  )

  # Block 3: VIF table (only when VIFs were computed).
  if (!is.null(vif_results) && length(vif_results) > 0) {
    vif_headers <- c("变量", "VIF")
    vif_rows <- lapply(vif_results, function(row) c(row$variable, as.character(row$vif)))
    blocks[[length(blocks) + 1]] <- make_table_block(
      vif_headers, vif_rows,
      title = "方差膨胀因子 (VIF)"
    )
  }

  # Block 4: forest plot (only when it was rendered).
  if (!is.null(plot_base64)) {
    blocks[[length(blocks) + 1]] <- make_image_block(
      plot_base64,
      title = "森林图", alt = "Odds Ratios Forest Plot"
    )
  }

  # Block 5: conclusion summary. The intercept is never listed as significant.
  is_sig_term <- vapply(
    coefficients_list,
    function(r) r$variable != "(Intercept)" && r$significant,
    logical(1)
  )
  sig_vars <- vapply(coefficients_list[is_sig_term], function(r) r$variable, character(1))

  conclusion_lines <- c(
    glue("模型拟合指标AIC = {round(aic, 2)}Nagelkerke R² = {round(r2_nagelkerke, 4)}。"),
    ""
  )
  if (length(sig_vars) > 0) {
    conclusion_lines <- c(
      conclusion_lines,
      glue("α = 0.05 水平下,以下变量具有统计学意义:**{paste(sig_vars, collapse = '**, **')}**。"),
      ""
    )
  } else {
    conclusion_lines <- c(conclusion_lines, "α = 0.05 水平下,无预测变量达到统计学意义。", "")
  }
  conclusion_lines <- c(conclusion_lines, glue("EPV = {round(epv, 1)}(建议 ≥ 10。"))

  blocks[[length(blocks) + 1]] <- make_markdown_block(
    paste(conclusion_lines, collapse = "\n"),
    title = "结论摘要"
  )

  # ===== Result =====
  return(list(
    status = "success",
    message = "分析完成",
    warnings = if (length(warnings_list) > 0) warnings_list else NULL,
    results = list(
      method = "Binary Logistic Regression (glm, binomial)",
      formula = formula_str,
      n_observations = nrow(df),
      n_predictors = n_predictors,
      coefficients = coefficients_list,
      model_fit = list(
        aic = jsonlite::unbox(round(aic, 2)),
        null_deviance = jsonlite::unbox(round(null_deviance, 2)),
        residual_deviance = jsonlite::unbox(round(residual_deviance, 2)),
        r2_nagelkerke = jsonlite::unbox(round(r2_nagelkerke, 4))
      ),
      vif = vif_results,
      epv = jsonlite::unbox(round(epv, 1))
    ),
    report_blocks = blocks,
    plots = if (!is.null(plot_base64)) list(plot_base64) else list(),
    trace_log = logs,
    reproducible_code = as.character(reproducible_code)
  ))
}
# 辅助函数:生成森林图
#' Render a forest plot of odds ratios as a base64 PNG data URI.
#'
#' @param coef_table A data frame with columns `variable`, `OR`, `ci_lower`
#'   and `ci_upper`. Rows whose `variable` is "(Intercept)" are not plotted.
#' @return A string of the form "data:image/png;base64,...", or `NULL` when no
#'   rows remain after the intercept is removed.
generate_forest_plot <- function(coef_table) {
  # Drop the intercept row; only predictor terms are plotted.
  plot_data <- coef_table[coef_table$variable != "(Intercept)", ]

  if (nrow(plot_data) == 0) {
    return(NULL)
  }

  # Reverse the level order so the first term ends up at the top of the y axis.
  plot_data$variable <- factor(plot_data$variable, levels = rev(plot_data$variable))

  p <- ggplot(plot_data, aes(x = OR, y = variable)) +
    geom_vline(xintercept = 1, linetype = "dashed", color = "gray50") +
    geom_point(size = 3, color = "#3b82f6") +
    geom_errorbarh(aes(xmin = ci_lower, xmax = ci_upper), height = 0.2, color = "#3b82f6") +
    scale_x_log10() +
    theme_minimal() +
    labs(
      title = "Forest Plot: Odds Ratios with 95% CI",
      x = "Odds Ratio (log scale)",
      y = "Variable"
    ) +
    theme(
      panel.grid.minor = element_blank(),
      axis.text.y = element_text(size = 10)
    )

  tmp_file <- tempfile(fileext = ".png")
  # Register cleanup before writing so the temp file is removed even when
  # ggsave() or the base64 encoding step throws.
  on.exit(unlink(tmp_file), add = TRUE)

  # Height grows with the number of plotted terms, with a 4-inch minimum.
  ggsave(tmp_file, p, width = 8, height = max(4, nrow(plot_data) * 0.5 + 2), dpi = 100)

  paste0("data:image/png;base64,", base64encode(tmp_file))
}