# utils/guardrails.R
# 统计护栏函数库

library(glue)

# Sample size above which check_normality() tests a random subsample
# instead of the full vector.
LARGE_SAMPLE_THRESHOLD <- 5000

# Guardrail action types
ACTION_BLOCK <- "Block"   # stop execution
ACTION_WARN <- "Warn"     # warn but continue
ACTION_SWITCH <- "Switch" # switch to the fallback method

#' Normality guardrail based on the Shapiro-Wilk test
#'
#' Missing values are dropped before anything else. With fewer than 3 usable
#' observations the check is skipped and reported as passed. With more than
#' `LARGE_SAMPLE_THRESHOLD` observations the test runs on a fixed-seed random
#' subsample of 1000 values; otherwise it runs on all values.
#'
#' Errors raised by `shapiro.test()` itself (for example when all values are
#' identical) are propagated to the caller.
#'
#' @param values Numeric vector to test.
#' @param alpha Significance level; the check passes when `p >= alpha`.
#' @param action Action reported when the check fails.
#' @param action_target Value reported as `action_target` when the check
#'   fails (`run_guardrail_chain()` returns it as `target_tool` for a
#'   "Switch" action).
#' @return A list with `passed`, `action`, `action_target` and `reason`.
#'   The skip case adds `skipped = TRUE`; tested cases add `p_value` and
#'   `sampled`, plus `sample_size` when a subsample was used. `action` and
#'   `action_target` are `NULL` whenever the check passes.
check_normality <- function(values, alpha = 0.05, action = ACTION_SWITCH,
                            action_target = NULL) {
  # shapiro.test() drops NAs internally, so only usable observations may
  # count towards the size guards; otherwise c(1, 2, NA) would slip past the
  # n < 3 guard and then make shapiro.test() fail.
  values <- values[!is.na(values)]
  n <- length(values)

  # Too few observations: skip the test rather than block the analysis.
  if (n < 3) {
    return(list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = "样本量过小,跳过正态性检验",
      skipped = TRUE
    ))
  }

  sampled <- n > LARGE_SAMPLE_THRESHOLD
  sample_size <- 1000

  if (sampled) {
    # A fixed seed keeps the subsample reproducible. The caller's RNG state
    # is restored on exit so this check does not perturb later random draws.
    global_env <- globalenv()
    had_seed <- exists(".Random.seed", envir = global_env, inherits = FALSE)
    if (had_seed) {
      saved_seed <- get(".Random.seed", envir = global_env, inherits = FALSE)
    }
    on.exit(
      if (had_seed) {
        assign(".Random.seed", saved_seed, envir = global_env)
      } else if (exists(".Random.seed", envir = global_env,
                        inherits = FALSE)) {
        rm(".Random.seed", envir = global_env)
      },
      add = TRUE
    )
    set.seed(42)
    test <- shapiro.test(sample(values, sample_size))
  } else {
    test <- shapiro.test(values)
  }

  passed <- test$p.value >= alpha
  verdict <- if (passed) "满足正态性" else "不满足正态性"

  # The reason is a plain character string in every branch.
  reason <- if (sampled) {
    paste0("大样本(N=", n, ")抽样检验,", verdict)
  } else {
    verdict
  }

  result <- list(
    passed = passed,
    action = if (passed) NULL else action,
    action_target = if (passed) NULL else action_target,
    p_value = test$p.value,
    reason = reason,
    sampled = sampled
  )
  if (sampled) {
    result$sample_size <- sample_size
  }
  result
}

#' Homogeneity-of-variance guardrail (Levene's test)
#'
#' Builds the formula `value_var ~ group_var` and runs `car::leveneTest()` on
#' `df`. The check passes when the p-value in the first row of the returned
#' table is at least `alpha`.
#'
#' @param df Data frame holding both columns.
#' @param group_var Name of the grouping column, as a string.
#' @param value_var Name of the measured column, as a string.
#' @param alpha Significance level; the check passes when `p >= alpha`.
#' @param action Action reported when the check fails.
#' @return A list with `passed`, `action` (`NULL` when passed), `p_value`
#'   and `reason`.
check_homogeneity <- function(df, group_var, value_var, alpha = 0.05,
                              action = ACTION_WARN) {
  # Call car through its namespace instead of library(car): attaching a
  # package from inside a function changes the caller's search path.
  if (!requireNamespace("car", quietly = TRUE)) {
    stop("Package 'car' is required for check_homogeneity()", call. = FALSE)
  }

  model_formula <- as.formula(paste(value_var, "~", group_var))
  levene <- car::leveneTest(model_formula, data = df)
  p_val <- levene$`Pr(>F)`[1]
  passed <- p_val >= alpha

  list(
    passed = passed,
    action = if (passed) NULL else action,
    p_value = p_val,
    reason = if (passed) "方差齐性满足" else "方差不齐性"
  )
}

#' Minimum sample size guardrail
#'
#' @param n Observed sample size.
#' @param min_required Smallest acceptable sample size.
#' @param action Action reported when the check fails.
#' @return A list with `passed`, `action` (`NULL` when passed), `n` and
#'   `reason`.
check_sample_size <- function(n, min_required = 3, action = ACTION_BLOCK) {
  # isTRUE() makes an unknown size (NA, NULL, length-0) count as a failed
  # check instead of crashing the `if` conditions below.
  passed <- isTRUE(n >= min_required)

  list(
    passed = passed,
    action = if (passed) NULL else action,
    n = n,
    reason = if (passed) {
      "样本量充足"
    } else {
      paste0("样本量不足, 需要至少 ", min_required)
    }
  )
}

#' Combine individual guardrail results into one decision
#'
#' Results are processed in the order they appear in `guardrail_results`;
#' this function does no sorting, so any required ordering must be applied
#' by the caller. The first failing result whose action is "Block" or
#' "Switch" ends the chain immediately, and warnings collected before it are
#' not included in the return value. Failing "Warn" results are accumulated.
#' Failing results with any other (or no) action are ignored.
#'
#' @param guardrail_results List of guardrail result lists, each with
#'   `passed`, `action` and `reason` (and `action_target` for "Switch").
#' @return One of:
#'   `list(status = "blocked", reason)`,
#'   `list(status = "switch", target_tool, reason)`, or
#'   `list(status = "passed", warnings)` where `warnings` is a character
#'   vector of reasons, or `NULL` when there were none.
run_guardrail_chain <- function(guardrail_results) {
  warnings <- NULL

  for (result in guardrail_results) {
    # Anything other than an explicit TRUE (FALSE, NA, missing) counts as a
    # failed check.
    if (isTRUE(result$passed)) {
      next
    }

    # identical() is safe when a failing result carries no action (NULL),
    # where `==` would make the `if` condition zero-length and error.
    action <- result$action
    if (identical(action, ACTION_BLOCK)) {
      return(list(
        status = "blocked",
        reason = result$reason
      ))
    }
    if (identical(action, ACTION_SWITCH)) {
      return(list(
        status = "switch",
        target_tool = result$action_target,
        reason = result$reason
      ))
    }
    if (identical(action, ACTION_WARN)) {
      warnings <- c(warnings, result$reason)
    }
  }

  list(
    status = "passed",
    warnings = warnings
  )
}