Files
AIclinicalresearch/r-statistics-service/utils/guardrails.R
HaHafeng 428a22adf2 feat(ssa): Complete Phase 2A frontend integration - multi-step workflow end-to-end
Phase 2A: WorkflowPlannerService, WorkflowExecutorService, Python data quality, 6 bug fixes, DescriptiveResultView, multi-step R code/Word export, MVP UI reuse. V11 UI: Gemini-style, multi-task, single-page scroll, Word export. Architecture: Block-based rendering consensus (4 block types). New R tools: chi_square, correlation, descriptive, logistic_binary, mann_whitney, t_test_paired. Docs: dev summary, block-based plan, status updates, task list v2.0.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-20 23:09:27 +08:00

243 lines
7.3 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# utils/guardrails.R
# 统计护栏函数库
library(glue)
# Large-sample threshold: check_normality() tests a random subsample instead
# of the full vector when the sample size exceeds this value.
LARGE_SAMPLE_THRESHOLD <- 5000
# Guardrail action types reported when a check fails
ACTION_BLOCK <- "Block" # stop execution
ACTION_WARN <- "Warn" # warn but continue
ACTION_SWITCH <- "Switch" # switch to an alternative method
# Normality check (supports all three action types)
#' Shapiro-Wilk normality guardrail
#'
#' Missing values are removed before anything else, so the size checks and the
#' reported N refer to the observations that are actually tested.
#'
#' @param values Numeric vector to test.
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action reported when the check fails.
#' @param action_target Value reported as `action_target` when the check fails.
#' @return A list with `passed`, `action`, `action_target` and `reason`.
#'   When the test ran it also carries `p_value` and `sampled` (plus
#'   `sample_size` for the subsampled large-sample path); when the test could
#'   not be run it carries `skipped = TRUE` instead and `passed` is TRUE.
check_normality <- function(values, alpha = 0.05, action = ACTION_SWITCH, action_target = NULL) {
  skip_result <- function(reason) {
    list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = reason,
      skipped = TRUE
    )
  }

  # shapiro.test() drops NAs itself, so count only usable observations;
  # otherwise a mostly-missing vector slips past the size check and errors.
  values <- values[!is.na(values)]
  n <- length(values)

  # Too few observations for the test
  if (n < 3) {
    return(skip_result("样本量过小,跳过正态性检验"))
  }

  # Large-sample optimization: test a fixed-seed random subsample
  sampled <- n > LARGE_SAMPLE_THRESHOLD
  sample_size <- min(1000, n)
  test_values <- values
  if (sampled) {
    # The fixed seed keeps the result reproducible, but it must not leak into
    # the caller's random stream: snapshot the global RNG state and restore it
    # when this function exits.
    global_env <- globalenv()
    had_seed <- exists(".Random.seed", envir = global_env, inherits = FALSE)
    old_seed <- if (had_seed) get(".Random.seed", envir = global_env) else NULL
    on.exit({
      if (had_seed) {
        assign(".Random.seed", old_seed, envir = global_env)
      } else if (exists(".Random.seed", envir = global_env, inherits = FALSE)) {
        rm(".Random.seed", envir = global_env)
      }
    }, add = TRUE)
    set.seed(42)
    test_values <- sample(values, sample_size)
  }

  # shapiro.test() stops on (near-)constant data; report a skip instead,
  # using the same tolerance shapiro.test() applies internally.
  if (isTRUE(max(test_values) - min(test_values) < 1e-10)) {
    return(skip_result("数据无变异(所有值相同),跳过正态性检验"))
  }

  test <- shapiro.test(test_values)
  passed <- test$p.value >= alpha

  reason <- if (sampled) {
    glue("大样本(N={n})抽样检验,{if (passed) '满足正态性' else '不满足正态性'}")
  } else if (passed) {
    "满足正态性"
  } else {
    "不满足正态性"
  }

  result <- list(
    passed = passed,
    action = if (passed) NULL else action,
    action_target = if (passed) NULL else action_target,
    p_value = test$p.value,
    reason = reason,
    sampled = sampled
  )
  if (sampled) {
    result <- c(result, list(sample_size = sample_size))
  }
  result
}
# Homogeneity-of-variance check (Levene's test)
#' Levene's test guardrail for equal variances across groups
#'
#' @param df Data frame holding both columns.
#' @param group_var Name of the grouping column (character scalar).
#' @param value_var Name of the numeric outcome column (character scalar).
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action reported when the check fails.
#' @return A list with `passed`, `action` (NULL when passed), `p_value` and
#'   `reason`.
check_homogeneity <- function(df, group_var, value_var, alpha = 0.05, action = ACTION_WARN) {
  # Call car through its namespace rather than attaching it with library(),
  # which would change the search path as a side effect of a check.
  if (!requireNamespace("car", quietly = TRUE)) {
    stop("Package 'car' is required for Levene's test", call. = FALSE)
  }

  missing_cols <- setdiff(c(group_var, value_var), names(df))
  if (length(missing_cols) > 0) {
    stop(
      "Column(s) not found in data: ", paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }

  # Use the vector interface instead of a pasted formula: it works for column
  # names that are not syntactic R names, and the explicit factor() lets
  # numerically coded groups (e.g. 0/1) through, which the formula method
  # rejects as a quantitative explanatory variable.
  test <- car::leveneTest(
    y = df[[value_var]],
    group = factor(df[[group_var]])
  )
  p_val <- test$`Pr(>F)`[1]
  passed <- p_val >= alpha

  list(
    passed = passed,
    action = if (passed) NULL else action,
    p_value = p_val,
    reason = if (passed) "方差齐性满足" else "方差不齐性"
  )
}
# Sample-size check
#' Guardrail verifying that enough observations are available
#'
#' @param n Observed sample size.
#' @param min_required Smallest acceptable sample size.
#' @param action Action reported when `n` falls below `min_required`.
#' @return A list with `passed`, `action` (NULL when passed), `n` and `reason`.
check_sample_size <- function(n, min_required = 3, action = ACTION_BLOCK) {
  enough <- n >= min_required

  if (enough) {
    outcome_action <- NULL
    outcome_reason <- "样本量充足"
  } else {
    outcome_action <- action
    outcome_reason <- paste0("样本量不足, 需要至少 ", min_required)
  }

  list(
    passed = enough,
    action = outcome_action,
    n = n,
    reason = outcome_reason
  )
}
# Run the guardrail chain (results are processed in the order given)
#' Reduce a list of guardrail results to a single decision
#'
#' Results are scanned in order. The first failed result whose action is
#' ACTION_BLOCK or ACTION_SWITCH ends the scan immediately; failed results with
#' ACTION_WARN only contribute their reason to the warnings.
#'
#' @param guardrail_results List of result lists, each with `passed`, and for
#'   failures `action`, `reason` and optionally `action_target`.
#' @return `list(status = "blocked", reason)`,
#'   `list(status = "switch", target_tool, reason)`, or
#'   `list(status = "passed", warnings)` where `warnings` is NULL when no
#'   warning was collected.
run_guardrail_chain <- function(guardrail_results) {
  # Scalar-safe comparison: a NULL, NA or multi-element action never matches,
  # instead of making `if` fail with "argument is of length zero".
  has_action <- function(action, expected) {
    length(action) == 1L && !is.na(action) && action == expected
  }

  warnings <- c()
  for (result in guardrail_results) {
    # Anything other than an explicit TRUE (FALSE, NA, missing) is a failure,
    # so an indeterminate check is never silently treated as passed.
    if (isTRUE(result$passed)) {
      next
    }

    if (has_action(result$action, ACTION_BLOCK)) {
      return(list(
        status = "blocked",
        reason = result$reason
      ))
    }
    if (has_action(result$action, ACTION_SWITCH)) {
      return(list(
        status = "switch",
        target_tool = result$action_target,
        reason = result$reason
      ))
    }
    if (has_action(result$action, ACTION_WARN)) {
      warnings <- c(warnings, result$reason)
    }
    # Failed results with no recognized action are ignored.
  }

  list(
    status = "passed",
    warnings = warnings
  )
}
# ========== JIT guardrail interface (Phase 2A) ==========
# Called by WorkflowExecutor before it runs a core statistical tool
#' JIT guardrails: test statistical assumptions before the core analysis
#'
#' Depending on `tool_code`, runs per-group normality plus Levene's test
#' (ST_T_TEST_IND / ST_MANN_WHITNEY), normality of paired differences
#' (ST_T_TEST_PAIRED), or normality of both variables (ST_CORRELATION).
#' Other tool codes produce no checks. Missing values are dropped before each
#' normality check, and a check that cannot be computed is logged with
#' message() and left out of the result rather than aborting the call.
#'
#' @param df Data frame.
#' @param tool_code Code of the tool about to be run.
#' @param params Tool parameters; reads `group_var`/`value_var`,
#'   `before_var`/`after_var` or `var_x`/`var_y` depending on `tool_code`.
#' @return list(checks, suggested_tool, can_proceed, all_checks_passed).
#'   `can_proceed` is always TRUE: failed checks are advisory.
run_jit_guardrails <- function(df, tool_code, params) {
  checks <- list()
  suggested_tool <- tool_code

  # Runs check_normality() on the non-missing values and formats the outcome
  # as a check entry. Returns NULL when fewer than 3 values remain or when the
  # test itself fails.
  normality_check <- function(values, check_name, pass_text, fail_text) {
    values <- values[!is.na(values)]
    if (length(values) < 3) {
      return(NULL)
    }
    tryCatch({
      result <- check_normality(values, alpha = 0.05)
      list(
        check_name = check_name,
        passed = result$passed,
        p_value = result$p_value,
        recommendation = if (result$passed) pass_text else fail_text
      )
    }, error = function(e) {
      message("正态性检验失败: ", conditionMessage(e))
      NULL
    })
  }

  # Run the checks that apply to this tool type
  if (tool_code %in% c("ST_T_TEST_IND", "ST_MANN_WHITNEY")) {
    # Independent-samples comparison: normality per group + equal variances
    group_var <- params$group_var
    value_var <- params$value_var
    if (!is.null(group_var) && !is.null(value_var)) {
      # Extract columns with [[ ]] so tibbles and data frames behave alike
      group_values <- df[[group_var]]
      outcome_values <- df[[value_var]]
      # A missing group label is not a group
      groups <- unique(group_values)
      groups <- groups[!is.na(groups)]

      # Normality check (per group)
      for (g in groups) {
        label <- glue("正态性检验 (组: {g})")
        # which() drops rows whose group label is NA
        check <- normality_check(
          outcome_values[which(group_values == g)],
          label,
          "满足正态性",
          "建议使用非参数方法"
        )
        if (!is.null(check)) {
          checks <- c(checks, list(check))
          if (!check$passed && tool_code == "ST_T_TEST_IND") {
            suggested_tool <- "ST_MANN_WHITNEY"
          }
        }
      }

      # Homogeneity of variance
      if (length(groups) == 2) {
        homo_check <- tryCatch({
          homo_result <- check_homogeneity(df, group_var, value_var, alpha = 0.05)
          list(
            check_name = "方差齐性检验 (Levene)",
            passed = homo_result$passed,
            p_value = homo_result$p_value,
            recommendation = if (homo_result$passed) "方差齐性满足" else "建议使用 Welch 校正"
          )
        }, error = function(e) {
          message("方差齐性检验失败: ", conditionMessage(e))
          NULL
        })
        if (!is.null(homo_check)) {
          checks <- c(checks, list(homo_check))
        }
      }
    }
  } else if (tool_code == "ST_T_TEST_PAIRED") {
    # Paired test: normality of the differences
    before_var <- params$before_var
    after_var <- params$after_var
    if (!is.null(before_var) && !is.null(after_var)) {
      diff_vals <- df[[after_var]] - df[[before_var]]
      check <- normality_check(
        diff_vals,
        "差值正态性检验",
        "差值满足正态性",
        "建议使用 Wilcoxon 符号秩检验"
      )
      if (!is.null(check)) {
        checks <- c(checks, list(check))
        if (!check$passed) {
          # NOTE(review): unlike the other branches this is a display name,
          # not an ST_* tool code - confirm what callers expect here.
          suggested_tool <- "Wilcoxon signed-rank test"
        }
      }
    }
  } else if (tool_code == "ST_CORRELATION") {
    # Correlation: normality of both variables
    var_x <- params$var_x
    var_y <- params$var_y
    if (!is.null(var_x) && !is.null(var_y)) {
      for (var_name in c(var_x, var_y)) {
        label <- glue("正态性检验 ({var_name})")
        check <- normality_check(
          df[[var_name]],
          label,
          "满足正态性",
          "建议使用 Spearman 秩相关"
        )
        if (!is.null(check)) {
          checks <- c(checks, list(check))
          # Either variable failing is enough to suggest Spearman; each check
          # is judged on its own, so one variable being too short to test
          # neither hides the other's failure nor triggers a lookup of a
          # result that was never computed.
          if (!check$passed) {
            # NOTE(review): not a plain ST_* tool code - confirm consumers
            # handle this label.
            suggested_tool <- "ST_CORRELATION (Spearman)"
          }
        }
      }
    }
  }

  # Summary (TRUE when no checks were run)
  all_passed <- all(vapply(checks, function(chk) isTRUE(chk$passed), logical(1)))

  list(
    checks = checks,
    suggested_tool = suggested_tool,
    can_proceed = TRUE, # proceed even when checks fail; the user/LLM decides
    all_checks_passed = all_passed
  )
}