Phase I - Session Blackboard + READ Layer: - SessionBlackboardService with Postgres-Only cache - DataProfileService for data overview generation - PicoInferenceService for LLM-driven PICO extraction - Frontend DataContextCard and VariableDictionaryPanel - E2E tests: 31/31 passed Phase II - Conversation Layer LLM + Intent Router: - ConversationService with SSE streaming - IntentRouterService (rule-first + LLM fallback, 6 intents) - SystemPromptService with 6-segment dynamic assembly - TokenTruncationService for context management - ChatHandlerService as unified chat entry - Frontend SSAChatPane and useSSAChat hook - E2E tests: 38/38 passed Phase III - Method Consultation + AskUser Standardization: - ToolRegistryService with Repository Pattern - MethodConsultService with DecisionTable + LLM enhancement - AskUserService with global interrupt handling - Frontend AskUserCard component - E2E tests: 13/13 passed Phase IV - Dialogue-Driven Analysis + QPER Integration: - ToolOrchestratorService (plan/execute/report) - analysis_plan SSE event for WorkflowPlan transmission - Dual-channel confirmation (ask_user card + workspace button) - PICO as optional hint for LLM parsing - E2E tests: 25/25 passed R Statistics Service: - 5 new R tools: anova_one, baseline_table, fisher, linear_reg, wilcoxon - Enhanced guardrails and block helpers - Comprehensive test suite (run_all_tools_test.js) Documentation: - Updated system status document (v5.9) - Updated SSA module status and development plan (v1.8) Total E2E: 107/107 passed (Phase I: 31, Phase II: 38, Phase III: 13, Phase IV: 25) Co-authored-by: Cursor <cursoragent@cursor.com>
335 lines
10 KiB
R
335 lines
10 KiB
R
# utils/guardrails.R
|
||
# 统计护栏函数库
|
||
|
||
library(glue)
|
||
|
||
# Sample size above which check_normality() tests a random subsample
# instead of the full vector.
LARGE_SAMPLE_THRESHOLD <- 5000

# Guardrail action types: what should happen when a check does not pass.
ACTION_BLOCK <- "Block"    # stop execution
ACTION_WARN <- "Warn"      # warn but continue
ACTION_SWITCH <- "Switch"  # switch to the alternative method
|
||
|
||
#' Normality guardrail (Shapiro-Wilk), supporting all three action types
#'
#' Missing values are dropped before anything else so that the size checks
#' below see the same number of observations that shapiro.test() will use.
#'
#' @param values Numeric vector to test.
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action reported when the check fails (one of the ACTION_*
#'   constants).
#' @param action_target Alternative tool reported together with `action` when
#'   the check fails.
#' @return A list with `passed`, `action`, `action_target` and `reason`.
#'   When the test was run it also has `p_value` and `sampled` (plus
#'   `sample_size` for the subsampled case); when it was skipped because fewer
#'   than 3 observations are available it has `skipped = TRUE` instead.
check_normality <- function(values, alpha = 0.05, action = ACTION_SWITCH, action_target = NULL) {
  values <- values[!is.na(values)]
  n <- length(values)

  # Too few observations: skip the test and let the chain continue.
  if (n < 3) {
    return(list(
      passed = TRUE,
      action = NULL,
      action_target = NULL,
      reason = "样本量过小,跳过正态性检验",
      skipped = TRUE
    ))
  }

  # Large samples: test a fixed-size random subsample.
  if (n > LARGE_SAMPLE_THRESHOLD) {
    subsample_size <- 1000

    # The subsample uses a fixed seed so results are reproducible, but the
    # caller's RNG stream must not be altered as a side effect: remember the
    # current state and put it back when the function exits.
    global_env <- globalenv()
    had_seed <- exists(".Random.seed", envir = global_env, inherits = FALSE)
    saved_seed <- if (had_seed) {
      get(".Random.seed", envir = global_env, inherits = FALSE)
    } else {
      NULL
    }
    on.exit({
      if (had_seed) {
        assign(".Random.seed", saved_seed, envir = global_env)
      } else if (exists(".Random.seed", envir = global_env, inherits = FALSE)) {
        rm(".Random.seed", envir = global_env)
      }
    }, add = TRUE)

    set.seed(42)
    sampled_values <- sample(values, subsample_size)
    test <- shapiro.test(sampled_values)
    passed <- test$p.value >= alpha

    return(list(
      passed = passed,
      action = if (passed) NULL else action,
      action_target = if (passed) NULL else action_target,
      p_value = test$p.value,
      reason = glue("大样本(N={n})抽样检验,{if (passed) '满足正态性' else '不满足正态性'}"),
      sampled = TRUE,
      sample_size = subsample_size
    ))
  }

  # Regular case: test the full vector.
  test <- shapiro.test(values)
  passed <- test$p.value >= alpha

  list(
    passed = passed,
    action = if (passed) NULL else action,
    action_target = if (passed) NULL else action_target,
    p_value = test$p.value,
    reason = if (passed) "满足正态性" else "不满足正态性",
    sampled = FALSE
  )
}
|
||
|
||
#' Homogeneity-of-variance guardrail (Levene's test)
#'
#' @param df Data frame containing both columns.
#' @param group_var Name of the grouping column.
#' @param value_var Name of the numeric outcome column.
#' @param alpha Significance level; the check passes when p >= alpha.
#' @param action Action reported when the check fails (one of the ACTION_*
#'   constants).
#' @return A list with `passed`, `action` (NULL when passed), `p_value` and
#'   `reason`.
check_homogeneity <- function(df, group_var, value_var, alpha = 0.05, action = ACTION_WARN) {
  # Call the default method on the two columns directly instead of pasting a
  # formula together: this works for column names that are not syntactic R
  # names, and converting the grouping column to a factor lets numerically
  # coded groups (e.g. 0/1) be tested. `car::` avoids attaching the package
  # to the search path as a side effect of calling this function.
  test <- car::leveneTest(
    y = df[[value_var]],
    group = factor(df[[group_var]])
  )
  p_val <- test$`Pr(>F)`[1]

  # Fail with an explicit message rather than the opaque
  # "missing value where TRUE/FALSE needed" raised by `if (NA)`.
  if (is.na(p_val)) {
    stop("Levene's test did not produce a valid p-value", call. = FALSE)
  }

  passed <- p_val >= alpha

  list(
    passed = passed,
    action = if (passed) NULL else action,
    p_value = p_val,
    reason = if (passed) "方差齐性满足" else "方差不齐性"
  )
}
|
||
|
||
#' Sample-size guardrail
#'
#' @param n Number of observations available.
#' @param min_required Minimum number of observations needed to pass.
#' @param action Action reported when the check fails (one of the ACTION_*
#'   constants).
#' @return A list with `passed`, `action` (NULL when passed), `n` and `reason`.
check_sample_size <- function(n, min_required = 3, action = ACTION_BLOCK) {
  enough <- n >= min_required

  if (enough) {
    outcome <- list(
      passed = enough,
      action = NULL,
      n = n,
      reason = "样本量充足"
    )
  } else {
    outcome <- list(
      passed = enough,
      action = action,
      n = n,
      reason = paste0("样本量不足, 需要至少 ", min_required)
    )
  }

  outcome
}
|
||
|
||
#' Evaluate a sequence of guardrail results in the order given
#'
#' The first failed result whose action is ACTION_BLOCK or ACTION_SWITCH ends
#' the chain immediately. Failed results with ACTION_WARN only contribute their
#' reason to the collected warnings.
#'
#' @param guardrail_results List of guardrail results (as returned by the
#'   check_* functions), each with at least `passed`, and `action` / `reason`
#'   when not passed.
#' @return One of
#'   list(status = "blocked", reason),
#'   list(status = "switch", target_tool, reason), or
#'   list(status = "passed", warnings), where `warnings` is NULL if nothing
#'   was collected.
run_guardrail_chain <- function(guardrail_results) {
  collected <- c()

  for (idx in seq_along(guardrail_results)) {
    outcome <- guardrail_results[[idx]]

    if (outcome$passed) {
      next
    }

    verdict <- outcome$action

    if (verdict == ACTION_BLOCK) {
      return(list(
        status = "blocked",
        reason = outcome$reason
      ))
    }

    if (verdict == ACTION_SWITCH) {
      return(list(
        status = "switch",
        target_tool = outcome$action_target,
        reason = outcome$reason
      ))
    }

    if (verdict == ACTION_WARN) {
      collected <- c(collected, outcome$reason)
    }
  }

  list(
    status = "passed",
    warnings = collected
  )
}
|
||
|
||
# ========== JIT 护栏接口(Phase 2A) ==========
|
||
# 用于 WorkflowExecutor 在执行核心工具前调用
|
||
|
||
#' JIT guardrails: check statistical assumptions right before a core tool runs
#'
#' Checks performed per tool code:
#' - ST_T_TEST_IND / ST_MANN_WHITNEY: per-group normality, plus Levene's test
#'   when there are exactly two groups.
#' - ST_T_TEST_PAIRED: normality of the paired differences.
#' - ST_CORRELATION: normality of each variable.
#' - ST_ANOVA_ONE: per-group normality plus Levene's test.
#' - ST_WILCOXON: at least 5 complete pairs.
#' - ST_FISHER / ST_CHI_SQUARE: share of expected cell counts below 5.
#' - ST_LINEAR_REG: normality of the outcome variable.
#' Any other tool code produces no checks.
#'
#' Missing values (and rows whose group is missing) are ignored; a normality
#' check is only run when at least 3 values remain. A check that raises an
#' error is reported with message() and left out of the result.
#'
#' @param df Data frame holding the analysis variables.
#' @param tool_code Code of the tool about to be executed.
#' @param params Named list of tool parameters (group_var, value_var,
#'   before_var, after_var, var_x, var_y, var1, var2, outcome_var, depending
#'   on the tool).
#' @return list(checks, suggested_tool, can_proceed, all_checks_passed).
#'   `checks` is a list of entries with check_name, passed, recommendation and,
#'   for tests, p_value. `suggested_tool` is `tool_code` unless a failed check
#'   recommends an alternative. `can_proceed` is always TRUE: the user / LLM
#'   decides what to do with failed checks.
run_jit_guardrails <- function(df, tool_code, params) {
  checks <- list()
  suggested_tool <- tool_code

  # TRUE when at least one of the given check entries did not pass.
  any_failed <- function(entries) {
    any(vapply(entries, function(entry) !isTRUE(entry$passed), logical(1)))
  }

  # Normality check on one vector; returns a list holding zero or one entry.
  normality_entries <- function(values, label, fail_advice,
                                pass_advice = "满足正态性") {
    values <- values[!is.na(values)]
    if (length(values) < 3) {
      return(list())
    }
    tryCatch({
      res <- check_normality(values, alpha = 0.05)
      list(list(
        check_name = label,
        passed = res$passed,
        p_value = res$p_value,
        recommendation = if (res$passed) pass_advice else fail_advice
      ))
    }, error = function(e) {
      # Same policy as the Levene check: a pre-check that cannot be computed
      # (e.g. all values identical) must not abort the whole run.
      message("正态性检验失败: ", conditionMessage(e))
      list()
    })
  }

  # Non-missing group labels present in a grouping column.
  observed_groups <- function(group_var) {
    group_col <- df[[group_var]]
    unique(group_col[!is.na(group_col)])
  }

  # One normality entry per group that has enough data.
  grouped_normality <- function(group_var, value_var, fail_advice) {
    group_col <- df[[group_var]]
    value_col <- df[[value_var]]
    entries <- list()
    for (g in observed_groups(group_var)) {
      in_group <- !is.na(group_col) & group_col == g
      label <- glue("正态性检验 (组: {g})")
      entries <- c(
        entries,
        normality_entries(value_col[in_group], label, fail_advice)
      )
    }
    entries
  }

  # Levene's test; returns a list holding zero or one entry.
  homogeneity_entries <- function(group_var, value_var, fail_advice) {
    tryCatch({
      res <- check_homogeneity(df, group_var, value_var, alpha = 0.05)
      list(list(
        check_name = "方差齐性检验 (Levene)",
        passed = res$passed,
        p_value = res$p_value,
        recommendation = if (res$passed) "方差齐性满足" else fail_advice
      ))
    }, error = function(e) {
      message("方差齐性检验失败: ", conditionMessage(e))
      list()
    })
  }

  if (tool_code %in% c("ST_T_TEST_IND", "ST_MANN_WHITNEY")) {
    # Independent two-sample comparison: normality + variance homogeneity.
    group_var <- params$group_var
    value_var <- params$value_var

    if (!is.null(group_var) && !is.null(value_var)) {
      normality <- grouped_normality(group_var, value_var, "建议使用非参数方法")
      checks <- c(checks, normality)

      if (tool_code == "ST_T_TEST_IND" && any_failed(normality)) {
        suggested_tool <- "ST_MANN_WHITNEY"
      }

      if (length(observed_groups(group_var)) == 2) {
        checks <- c(
          checks,
          homogeneity_entries(group_var, value_var, "建议使用 Welch 校正")
        )
      }
    }

  } else if (tool_code == "ST_T_TEST_PAIRED") {
    # Paired test: normality of the differences.
    before_var <- params$before_var
    after_var <- params$after_var

    if (!is.null(before_var) && !is.null(after_var)) {
      diff_vals <- df[[after_var]] - df[[before_var]]
      normality <- normality_entries(
        diff_vals,
        "差值正态性检验",
        fail_advice = "建议使用 Wilcoxon 符号秩检验",
        pass_advice = "差值满足正态性"
      )
      checks <- c(checks, normality)

      if (any_failed(normality)) {
        # Report the tool code (handled further below in this function) so the
        # suggestion has the same form as the other tool switches.
        suggested_tool <- "ST_WILCOXON"
      }
    }

  } else if (tool_code == "ST_CORRELATION") {
    # Correlation: normality of each of the two variables.
    var_x <- params$var_x
    var_y <- params$var_y

    if (!is.null(var_x) && !is.null(var_y)) {
      normality <- c(
        normality_entries(
          df[[var_x]],
          glue("正态性检验 ({var_x})"),
          "建议使用 Spearman 秩相关"
        ),
        normality_entries(
          df[[var_y]],
          glue("正态性检验 ({var_y})"),
          "建议使用 Spearman 秩相关"
        )
      )
      checks <- c(checks, normality)

      # Decide from whichever checks actually ran; either variable failing is
      # enough, and a variable with too little data must not cause an error.
      if (any_failed(normality)) {
        suggested_tool <- "ST_CORRELATION (Spearman)"
      }
    }

  } else if (tool_code == "ST_ANOVA_ONE") {
    group_var <- params$group_var
    value_var <- params$value_var

    if (!is.null(group_var) && !is.null(value_var)) {
      normality <- grouped_normality(group_var, value_var, "建议使用 Kruskal-Wallis")
      checks <- c(checks, normality)

      if (any_failed(normality)) {
        suggested_tool <- "Kruskal-Wallis (内置于 ST_ANOVA_ONE)"
      }

      checks <- c(
        checks,
        homogeneity_entries(group_var, value_var, "建议使用 Welch ANOVA")
      )
    }

  } else if (tool_code == "ST_WILCOXON") {
    before_var <- params$before_var
    after_var <- params$after_var

    if (!is.null(before_var) && !is.null(after_var)) {
      diff_vals <- df[[after_var]] - df[[before_var]]
      enough_pairs <- sum(!is.na(diff_vals)) >= 5

      checks <- c(checks, list(list(
        check_name = "配对样本量检查",
        passed = enough_pairs,
        recommendation = if (enough_pairs) "样本量充足" else "配对样本量不足"
      )))
    }

  } else if (tool_code %in% c("ST_FISHER", "ST_CHI_SQUARE")) {
    var1 <- params$var1
    var2 <- params$var2

    if (!is.null(var1) && !is.null(var2)) {
      ct <- table(df[[var1]], df[[var2]])
      expected <- tryCatch(chisq.test(ct)$expected, error = function(e) NULL)

      if (!is.null(expected)) {
        # The check passes when at most 20% of the expected counts are < 5.
        low_pct <- sum(expected < 5) / length(expected)
        counts_ok <- low_pct <= 0.2

        checks <- c(checks, list(list(
          check_name = "期望频数检查",
          passed = counts_ok,
          recommendation = if (counts_ok) "期望频数满足卡方检验条件" else "建议使用 Fisher 精确检验"
        )))

        if (!counts_ok && tool_code == "ST_CHI_SQUARE") {
          suggested_tool <- "ST_FISHER"
        }
      }
    }

  } else if (tool_code == "ST_LINEAR_REG") {
    outcome_var <- params$outcome_var

    if (!is.null(outcome_var)) {
      checks <- c(checks, normality_entries(
        df[[outcome_var]],
        glue("结局变量正态性 ({outcome_var})"),
        "结局变量分布偏态,结果需谨慎解读"
      ))
    }
  }

  # Summary. vapply keeps the result logical even when no check was run.
  all_passed <- all(vapply(
    checks,
    function(entry) isTRUE(entry$passed),
    logical(1)
  ))

  list(
    checks = checks,
    suggested_tool = suggested_tool,
    can_proceed = TRUE,  # failed checks never block; the user / LLM decides
    all_checks_passed = all_passed
  )
}
|