Files
AIclinicalresearch/r-statistics-service/utils/guardrails.R

117 lines
3.0 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# utils/guardrails.R
# 统计护栏函数库
library(glue)
# Sample-size threshold for the large-sample optimisation: check_normality()
# tests a random subsample instead of the full data when N exceeds this value.
LARGE_SAMPLE_THRESHOLD <- 5000
# Guardrail action types reported by a failed check
ACTION_BLOCK <- "Block" # block execution
ACTION_WARN <- "Warn" # warn but continue
ACTION_SWITCH <- "Switch" # switch to the fallback method
# Normality guardrail based on the Shapiro-Wilk test (supports all three
# guardrail actions).
#
# Args:
#   values: Vector of observations to test. Missing values are dropped before
#     the sample size is evaluated, because shapiro.test() ignores them too.
#   alpha: Significance level; the check passes when p-value >= alpha.
#   action: Action reported when the check fails (default ACTION_SWITCH).
#   action_target: Value reported as `action_target` when the check fails.
#
# Returns:
#   A list with `passed`, `action`, `action_target` and `reason`. When the test
#   is actually run the list also carries `p_value` and `sampled` (plus
#   `sample_size` on the subsampled path). When fewer than 3 non-missing
#   observations are available the test is skipped and the list carries
#   `skipped = TRUE` with `passed = TRUE`.
check_normality <- function(values, alpha = 0.05, action = ACTION_SWITCH,
                            action_target = NULL) {
  # Count only the observations shapiro.test() will really see; counting NAs
  # lets inputs such as c(1, 2, NA) slip past the size guard and then fail.
  values <- values[!is.na(values)]
  n <- length(values)

  # Sample too small for the test: skip it and let the chain continue.
  if (n < 3) {
    return(list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = "样本量过小,跳过正态性检验",
      skipped = TRUE
    ))
  }

  # Large-sample optimisation: when N exceeds the threshold, test a fixed-size
  # random subsample instead of the full vector.
  if (n > LARGE_SAMPLE_THRESHOLD) {
    sample_size <- 1000

    # The fixed seed keeps the subsample reproducible, but it must not leak
    # into the caller's session: remember the global RNG state and put it
    # back (or remove it again if there was none) when the function exits.
    had_seed <- exists(".Random.seed", envir = globalenv(), inherits = FALSE)
    old_seed <- if (had_seed) {
      get(".Random.seed", envir = globalenv(), inherits = FALSE)
    } else {
      NULL
    }
    on.exit(
      {
        if (had_seed) {
          assign(".Random.seed", old_seed, envir = globalenv())
        } else if (exists(".Random.seed", envir = globalenv(),
                          inherits = FALSE)) {
          rm(".Random.seed", envir = globalenv())
        }
      },
      add = TRUE
    )

    set.seed(42)
    sampled_values <- sample(values, sample_size)
    test <- shapiro.test(sampled_values)
    passed <- test$p.value >= alpha
    return(list(
      passed = passed,
      action = if (passed) NULL else action,
      action_target = if (passed) NULL else action_target,
      p_value = test$p.value,
      reason = glue("大样本(N={n})抽样检验,{if (passed) '满足正态性' else '不满足正态性'}"),
      sampled = TRUE,
      sample_size = sample_size
    ))
  }

  # Regular path: test every non-missing observation.
  test <- shapiro.test(values)
  passed <- test$p.value >= alpha
  list(
    passed = passed,
    action = if (passed) NULL else action,
    action_target = if (passed) NULL else action_target,
    p_value = test$p.value,
    reason = if (passed) "满足正态性" else "不满足正态性",
    sampled = FALSE
  )
}
# Homogeneity-of-variance guardrail using Levene's test (car::leveneTest).
#
# Args:
#   df: Data frame holding the data.
#   group_var: Name of the grouping column (character string).
#   value_var: Name of the value column (character string).
#   alpha: Significance level; the check passes when p-value >= alpha.
#   action: Action reported when the check fails (default ACTION_WARN).
#
# Returns:
#   A list with `passed`, `action` (NULL when passed), `p_value` and `reason`.
check_homogeneity <- function(df, group_var, value_var, alpha = 0.05,
                              action = ACTION_WARN) {
  # A name that matches a column of `df` is used as a symbol, so column names
  # that are not syntactic R names (spaces, parentheses, leading digits, ...)
  # still produce a valid formula. Anything else is parsed as R code, which
  # keeps expression strings such as "log(y)" working as they did before.
  as_term <- function(x) {
    if (x %in% names(df)) {
      as.name(x)
    } else {
      parse(text = x, keep.source = FALSE)[[1]]
    }
  }
  model_formula <- as.formula(
    call("~", as_term(value_var), as_term(group_var))
  )

  # Namespace-qualified call: no need to attach `car` to the search path as a
  # side effect of running a guardrail.
  levene <- car::leveneTest(model_formula, data = df)
  p_val <- levene$`Pr(>F)`[1]
  passed <- p_val >= alpha

  list(
    passed = passed,
    action = if (passed) NULL else action,
    p_value = p_val,
    reason = if (passed) "方差齐性满足" else "方差不齐性"
  )
}
# Minimum-sample-size guardrail.
#
# Args:
#   n: Observed sample size.
#   min_required: Smallest acceptable sample size (inclusive).
#   action: Action reported when `n` is below `min_required`
#     (default ACTION_BLOCK).
#
# Returns:
#   A list with `passed`, `action` (NULL when passed), `n` and `reason`.
check_sample_size <- function(n, min_required = 3, action = ACTION_BLOCK) {
  meets_minimum <- n >= min_required

  if (meets_minimum) {
    verdict_action <- NULL
    verdict_reason <- "样本量充足"
  } else {
    verdict_action <- action
    verdict_reason <- paste0("样本量不足, 需要至少 ", min_required)
  }

  list(
    passed = meets_minimum,
    action = verdict_action,
    n = n,
    reason = verdict_reason
  )
}
# Evaluate a chain of guardrail results in the order given.
#
# Args:
#   guardrail_results: List of guardrail result lists. Each entry is read for
#     `passed`, `action`, `reason` and (for switches) `action_target`.
#
# Returns:
#   - list(status = "blocked", reason) at the first failed result whose action
#     is ACTION_BLOCK;
#   - list(status = "switch", target_tool, reason) at the first failed result
#     whose action is ACTION_SWITCH;
#   - otherwise list(status = "passed", warnings), where `warnings` collects
#     the reasons of every other failed result (NULL when there are none).
run_guardrail_chain <- function(guardrail_results) {
  warnings <- NULL

  for (result in guardrail_results) {
    if (result$passed) {
      next
    }

    # isTRUE() keeps the comparison safe when `action` is NULL or NA; a bare
    # `==` inside `if` would abort with "argument is of length zero".
    failed_action <- result$action

    if (isTRUE(failed_action == ACTION_BLOCK)) {
      return(list(
        status = "blocked",
        reason = result$reason
      ))
    }

    if (isTRUE(failed_action == ACTION_SWITCH)) {
      return(list(
        status = "switch",
        target_tool = result$action_target,
        reason = result$reason
      ))
    }

    # ACTION_WARN, and any failed result whose action is missing or not
    # recognised: keep the reason so a failed check is never silently lost.
    warnings <- c(warnings, result$reason)
  }

  list(
    status = "passed",
    warnings = warnings
  )
}